From db28a110a161979f3ba6d330e09aca69ceb34b5c Mon Sep 17 00:00:00 2001 From: gopeshh Date: Mon, 21 Apr 2025 12:01:45 +0000 Subject: [PATCH 01/34] changes for basic LLaDA style diffusion masking support --- fast_llm/data/data/gpt/data.py | 37 ++++++++- fast_llm/data/dataset/gpt/config.py | 50 +++++++++++++ fast_llm/layers/__init__.py | 1 + fast_llm/layers/language_model/head.py | 58 ++++++++++++++ fast_llm/layers/transformer/config.py | 26 +++++++ fast_llm/layers/transformer/preprocessing.py | 79 +++++++++++++++++++- fast_llm/models/gpt/model.py | 64 ++++++++++------ 7 files changed, 287 insertions(+), 28 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 02c1b6c0..dfcb0e1e 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -35,15 +35,48 @@ class GPTBatch: def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch: + """Collate function that supports LLaDA-style masking.""" stacked_ids = np.stack([sample.token_ids for sample in batch]) stacked_spans = None sequence_lengths = None + + token_ids = torch.from_numpy(stacked_ids) + + if sampling_parameters.diffusion.enabled: + batch_size, seq_len = token_ids.shape + device = token_ids.device + t = torch.rand(batch_size, device=device) + p_mask = (1 - sampling_parameters.diffusion.epsilon) * t + sampling_parameters.diffusion.epsilon + p_mask = torch.min(p_mask, torch.tensor(sampling_parameters.diffusion.max_mask_prob)) + p_mask = p_mask[:, None].expand(-1, seq_len) + + masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask + + if sampling_parameters.diffusion.pad_prob > 0: + pad_mask = torch.rand((batch_size,), device=device) < sampling_parameters.diffusion.pad_prob + if pad_mask.any(): + masked_indices[pad_mask] = True + + token_ids = torch.where(masked_indices, sampling_parameters.diffusion.mask_token_id, token_ids) + + if not stacked_spans: + stacked_spans = [] + stacked_spans.extend([ + torch.stack([masked_indices[i], p_mask[i]]) + for i in range(batch_size) + ]) + if sampling_parameters.use_loss_masking_spans: - stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] + if not stacked_spans: + stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] + if not sampling_parameters.cross_document_attention: sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch] + return GPTBatch( - token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths + token_ids=token_ids, + loss_masking_spans=stacked_spans, + sequence_lengths=sequence_lengths ) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index ed9128c6..d84818ce 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -64,6 +64,51 @@ class GPTSamplingConfig(SamplingConfig): ) +@config_class() +class DiffusionMaskingConfig(Config): + """Configuration for diffusion-based masking during data preparation.""" + + enabled: bool = Field( + default=False, + desc="Whether to use diffusion-based masking during training", + hint=FieldHint.feature + ) + + epsilon: float = Field( + default=1e-3, + desc="Minimum masking probability", + hint=FieldHint.performance, + valid=check_field(Assert.gt, 0) + ) + + max_mask_prob: float = Field( + default=0.15, + desc="Maximum masking probability", + hint=FieldHint.performance, + valid=check_field(Assert.gt, 0) + ) + + pad_prob: float = 
Field( + default=0.01, + desc="Probability of padding tokens for 1% of samples", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0) + ) + + mask_token_id: int = Field( + default=103, + desc="Token ID to use for masking", + hint=FieldHint.required + ) + + def _validate(self) -> None: + super()._validate() + Assert.lt(self.epsilon, self.max_mask_prob, "epsilon must be less than max_mask_prob") + Assert.lt(self.max_mask_prob, 1.0, "max_mask_prob must be less than 1.0") + if self.enabled: + Assert.is_not_none(self.mask_token_id, "mask_token_id must be set when masking is enabled") + + @dataclasses.dataclass(kw_only=True) class GPTSamplingParameters(SamplingParameters): """ @@ -77,6 +122,11 @@ class GPTSamplingParameters(SamplingParameters): # How many extra tokens to add to the sequence length. # This is used to provide labels even for the last tokens in the sequence. extra_tokens: int = 1 + + # Diffusion masking configuration + diffusion: DiffusionMaskingConfig = dataclasses.field( + default_factory=DiffusionMaskingConfig + ) @dataclasses.dataclass(kw_only=True) diff --git a/fast_llm/layers/__init__.py b/fast_llm/layers/__init__.py index e69de29b..8b137891 100644 --- a/fast_llm/layers/__init__.py +++ b/fast_llm/layers/__init__.py @@ -0,0 +1 @@ + diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 3cc348d0..b84cc8f9 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -298,3 +298,61 @@ def _logits_cross_entropy_forward_backward( # TODO: de-allocate earlier. del logits return loss, output_parallel_linear_backward(grad, context) + +class MLMHead(LanguageModelHead): + """ + A masked language model head for diffusion-based training. + """ + + def _logits_cross_entropy_forward_backward( + self, + input_: torch.Tensor, + labels: torch.Tensor | None, + weight: torch.Tensor, + grad_output: float, + kwargs: dict, + losses: dict | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + logits, context = output_parallel_linear_forward( + input_=input_, + weight=weight, + bias=None, + group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, + sequence_parallel=self._sequence_parallel and self._parallel_embeddings, + ) + + if self._z_loss_factor > 0.0: + logits = z_loss( + logits, + self._z_loss_factor, + self.training, + grad_output, + losses, + LanguageModelLossNames.z_loss, + logits_scale_factor=self._logits_scale_factor, + ) + + if labels is None: + return logits * self._logits_scale_factor, None + + masked_indices = kwargs['masked_indices'] + p_mask = kwargs['p_mask'] + + masked_logits = logits[masked_indices] + masked_labels = labels[masked_indices] + masked_p = p_mask[masked_indices] + + # Compute MLM loss + loss, grad = cross_entropy_forward_backward( + masked_logits.flatten(0, -2), + masked_labels, + group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, + grad_output=grad_output / masked_p, + implementation=self._cross_entropy_impl, + logits_scale_factor=self._logits_scale_factor, + ) + + loss = loss / (labels.shape[0] * labels.shape[1]) + + del logits + return loss, output_parallel_linear_backward(grad, context) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index cf409e77..17144132 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -469,11 +469,37 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: ) 
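As a point of reference for the DiffusionMaskingConfig parameters introduced in this patch (epsilon, max_mask_prob, mask_token_id), the forward noising step they control can be sketched in plain PyTorch as follows. This is a minimal illustrative helper mirroring the collate-time logic above, not part of the patch itself; the clamp to max_mask_prob is shown even though the patch treats it as optional, and pad_prob handling is omitted.

import torch

def apply_forward_masking(token_ids, epsilon=1e-3, max_mask_prob=0.15, mask_token_id=103):
    # token_ids: (batch, seq_len) tensor of clean token ids.
    batch_size, seq_len = token_ids.shape
    # One noise level t ~ U(0, 1) per sequence, mapped to a masking probability of at least epsilon.
    t = torch.rand(batch_size)
    p_mask = (1 - epsilon) * t + epsilon
    # Optional cap corresponding to max_mask_prob (left as an open question in the collate function).
    p_mask = torch.clamp(p_mask, max=max_mask_prob)
    p_mask = p_mask[:, None].expand(-1, seq_len)
    # Mask each position independently with its sequence's probability, then substitute the mask token.
    masked_indices = torch.rand(batch_size, seq_len) < p_mask
    noisy_ids = torch.where(masked_indices, torch.full_like(token_ids, mask_token_id), token_ids)
    return noisy_ids, masked_indices, p_mask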
+@config_class() +class DiffusionMaskingConfig(Config): + """Configuration for diffusion-based masking in the transformer model. + This config only contains model-specific parameters. For masking parameters, + refer to fast_llm.data.dataset.gpt.config.DiffusionMaskingConfig.""" + + enabled: bool = Field( + default=False, + desc="Whether to use diffusion-based masking during training", + hint=FieldHint.feature + ) + bidirectional_attention: bool = Field( + default=True, + desc="Whether to use bidirectional attention for masked tokens", + hint=FieldHint.feature + ) + + def _validate(self) -> None: + super()._validate() + + @config_class() class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): normalization: NormalizationConfig = FieldUpdate(default_factory=NormalizationConfig) rotary: RotaryConfig = FieldUpdate(default_factory=RotaryConfig) peft: TransformerPeftConfig = FieldUpdate(default_factory=TransformerPeftConfig) + diffusion: DiffusionMaskingConfig = Field( + default_factory=DiffusionMaskingConfig, + desc="Configuration for diffusion-based masking", + hint=FieldHint.feature + ) # Default: hidden_size**-0.5 # TODO: Allow custom initialization (InitializationConfig?) init_method_std: float = Field( diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index 2415a2f9..ba66cf3a 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -195,8 +195,8 @@ class BackupAttentionPreprocessor(Preprocessor): _scalar_dim: TensorDim _kv_channels_dim: TensorDim _rotary_embedding_frequencies: torch.Tensor - _mask: torch.Tensor - _mask_value: torch.Tensor + _mask: torch.Tensor | None + _mask_value: torch.Tensor | None _tensor_cache_max_sequence_length: int = -1 def __init__( @@ -204,6 +204,7 @@ def __init__( config: TransformerConfig, tensor_space: TensorSpace, ): + super().__init__() self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config @@ -263,7 +264,7 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: kwargs[TransformerKwargs.attention_mask_value] = TensorMeta.from_dims( (self._scalar_dim,), tensor_name=TransformerKwargs.attention_mask_value, - dtype=self._tensor_space.distributed_config.training_dtype.torch, + dtype=self._distributed_config.training_dtype.torch, ) @@ -337,3 +338,75 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: ) kwargs[TransformerKwargs.max_seqlen_q] = seqlens_q.max() kwargs[TransformerKwargs.max_seqlen_k] = seqlens_k.max() + + +class LLaDAMaskingPreprocessor(Preprocessor): + """Preprocessor for LLaDA-style masking with diffusion-based training.""" + + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace): + self._config = config + self._tensor_space = tensor_space + self._distributed_config = tensor_space.distributed_config + self._scalar_dim = tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._sequence_dim = tensor_space.get_tensor_dim(TransformerDimNames.sequence_q) + + def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + """Apply LLaDA-style masking to the input sequence.""" + # Get diffusion config from dataset parameters + diffusion_config = kwargs['parameters'].diffusion + if not diffusion_config.enabled: + return + + batch_size, seq_len = batch.shape + device = batch.device + + t = torch.rand(batch_size, device=device) + + p_mask = (1 - diffusion_config.epsilon) * t + diffusion_config.epsilon + p_mask 
= torch.min(p_mask, torch.tensor(diffusion_config.max_mask_prob)) + p_mask = p_mask[:, None].expand(-1, seq_len) + + masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask + + if diffusion_config.pad_prob > 0: + pad_mask = torch.rand((batch_size,), device=device) < diffusion_config.pad_prob + if pad_mask.any(): + masked_indices[pad_mask] = True + + kwargs['masked_indices'] = masked_indices + kwargs['p_mask'] = p_mask + + if self._config.diffusion.bidirectional_attention: + # Bidirectional attention - all tokens can attend to all other tokens + attention_mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool) + else: + # Causal attention + attention_mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool).tril_() + + + kwargs[TransformerKwargs.attention_mask] = attention_mask + kwargs[TransformerKwargs.attention_mask_value] = torch.tensor(-10000.0, device=device) + + def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + """Define tensor metadata for masking tensors.""" + # Get diffusion config from dataset parameters + diffusion_config = kwargs['parameters'].diffusion + if not diffusion_config.enabled: + return + + kwargs['masked_indices'] = TensorMeta.from_dims( + (self._scalar_dim, self._sequence_dim), + tensor_name='masked_indices' + ) + kwargs['p_mask'] = TensorMeta.from_dims( + (self._scalar_dim, self._sequence_dim), + tensor_name='p_mask' + ) + kwargs[TransformerKwargs.attention_mask] = TensorMeta.from_dims( + (self._scalar_dim, self._scalar_dim, self._sequence_dim, self._sequence_dim), + tensor_name=TransformerKwargs.attention_mask + ) + kwargs[TransformerKwargs.attention_mask_value] = TensorMeta.from_dims( + (self._scalar_dim,), + tensor_name=TransformerKwargs.attention_mask_value + ) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index a7ec58d6..0f1453d4 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -13,7 +13,7 @@ from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding -from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead +from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead, MLMHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor from fast_llm.layers.transformer.config import ( RoutingType, @@ -85,7 +85,11 @@ def get_output_layers(self) -> list[Layer]: # The previous layers return a stack of shared_hidden and transformer_output. 
return_input=i < self._config.prediction_heads - 1, ), - LanguageModelHead( + MLMHead( + self._config, + self._tensor_space, + prediction_distance=i, + ) if self._config.transformer.diffusion.enabled else LanguageModelHead( self._config, self._tensor_space, prediction_distance=i, @@ -332,38 +336,52 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: @property def loss_defs(self) -> list[LossDef]: loss_defs = [] - if ( - self._config.transformer.num_experts > 1 - and self._config.transformer.expert_routing_type == RoutingType.topk - ): + if self._config.transformer.diffusion.enabled: + # MLM loss for LLaDA-style training loss_defs.append( LossDef( - name=TransformerLossNames.load_balancing_loss, - formatted_name="load balancing loss", - count=self._config.transformer.num_layers, + name=LanguageModelLossNames.mlm_loss, + formatted_name="MLM Loss", + count=1, + dtype=torch.float32 ) ) - if self._config.transformer.expert_z_loss_coefficient: - loss_defs.append( - LossDef( - name=TransformerLossNames.router_z_loss, - formatted_name="router z loss", - count=self._config.transformer.num_layers, - ) - ) - if self._config.logit_z_loss: - LossDef(name=LanguageModelLossNames.z_loss, formatted_name="logit z loss", count=1) - - for i in range(self._config.prediction_heads): + else: + # Standard language modeling loss loss_defs.append( LossDef( - name=LanguageModelLossNames.multi_token_prediction_loss(i), - formatted_name=f"language model loss {i}", + name=LanguageModelLossNames.language_model_loss, + formatted_name="Language Model Loss", count=1, + dtype=torch.float32 ) ) + + if self._config.transformer.num_experts > 1: + if self._config.transformer.expert_routing_type == RoutingType.topk: + loss_defs.append( + LossDef( + name=TransformerLossNames.load_balancing_loss, + formatted_name="Load Balancing Loss", + count=1, + dtype=torch.float32 + ) + ) + if self._config.transformer.expert_z_loss_coefficient > 0: + loss_defs.append( + LossDef( + name=TransformerLossNames.router_z_loss, + formatted_name="Router Z Loss", + count=1, + dtype=torch.float32 + ) + ) return loss_defs + def forward(self, input_ids: torch.Tensor, kwargs: dict[str, typing.Any]) -> dict[str, torch.Tensor]: + outputs = super().forward(input_ids, kwargs) + return outputs + def add_preprocessor(self, preprocessor: Preprocessor): assert not self._is_setup self._preprocessors.append(preprocessor) From 3d44671cf7e85601a9ebcadbd23ca54ca5ba3e15 Mon Sep 17 00:00:00 2001 From: gopeshh Date: Tue, 22 Apr 2025 19:16:19 +0000 Subject: [PATCH 02/34] tests for masking and MLM loss --- tests/test_masking.py | 184 +++++++++++++++++++++++++++++++++++++++++ tests/test_mlm_loss.py | 149 +++++++++++++++++++++++++++++++++ 2 files changed, 333 insertions(+) create mode 100644 tests/test_masking.py create mode 100644 tests/test_mlm_loss.py diff --git a/tests/test_masking.py b/tests/test_masking.py new file mode 100644 index 00000000..30ddd69d --- /dev/null +++ b/tests/test_masking.py @@ -0,0 +1,184 @@ +import pytest +import torch + +from fast_llm.layers.language_model.preprocessing import LLaDAMaskingPreprocessor +from fast_llm.layers.transformer.config import DiffusionMaskingConfig + + +@pytest.fixture +def masking_config(): + return DiffusionMaskingConfig( + enabled=True, + epsilon=0.1, + max_mask_prob=0.5, + pad_prob=0.1, + mask_token_id=103 + ) + + +def test_masking_basic(): + config = DiffusionMaskingConfig( + enabled=True, + epsilon=0.15, # 15% minimum masking + max_mask_prob=0.5, # 50% maximum masking + pad_prob=0.1, + 
mask_token_id=103 + ) + + preprocessor = LLaDAMaskingPreprocessor(config) + + batch_size = 4 + seq_len = 10 + input_ids = torch.randint(0, 1000, (batch_size, seq_len)) + + input_ids[:, -2:] = 0 # Add padding at the end + + outputs = preprocessor(input_ids) + + masked_indices = outputs['masked_indices'] + p_mask = outputs['p_mask'] + masked_input = outputs['input_ids'] + + assert masked_indices.shape == input_ids.shape + assert p_mask.shape == input_ids.shape + assert masked_input.shape == input_ids.shape + + padding_positions = (input_ids == 0) + assert not masked_indices[padding_positions].any() + assert (p_mask[padding_positions] == 0).all() + + non_pad_positions = ~padding_positions + assert (p_mask[non_pad_positions] >= config.epsilon).all() + assert (p_mask[non_pad_positions] <= config.max_mask_prob).all() + + assert (masked_input[masked_indices] == config.mask_token_id).all() + + unmasked_positions = ~masked_indices & non_pad_positions + assert (masked_input[unmasked_positions] == input_ids[unmasked_positions]).all() + + +def test_masking_edge_cases(): + config = DiffusionMaskingConfig( + enabled=True, + epsilon=0.1, + max_mask_prob=0.5, + pad_prob=0.1, + mask_token_id=103 + ) + + preprocessor = LLaDAMaskingPreprocessor(config) + + input_ids = torch.randint(0, 1000, (1, 5)) + outputs = preprocessor(input_ids) + assert outputs['masked_indices'].shape == (1, 5) + assert outputs['p_mask'].shape == (1, 5) + + input_ids = torch.zeros(2, 4) + outputs = preprocessor(input_ids) + assert not outputs['masked_indices'].any() # No tokens should be masked + assert (outputs['p_mask'] == 0).all() # All masking probs should be 0 + + input_ids = torch.randint(1, 1000, (2, 4)) # All tokens are non-padding + outputs = preprocessor(input_ids) + assert outputs['masked_indices'].any() # Some tokens should be masked + assert (outputs['p_mask'] >= config.epsilon).all() # All probs should be >= epsilon + + input_ids = torch.randint(1, 1000, (1, 1)) + outputs = preprocessor(input_ids) + assert outputs['masked_indices'].shape == (1, 1) + assert outputs['p_mask'].shape == (1, 1) + + +def test_masking_probabilities(): + config = DiffusionMaskingConfig( + enabled=True, + epsilon=0.1, + max_mask_prob=0.5, + pad_prob=0.1, + mask_token_id=103 + ) + + preprocessor = LLaDAMaskingPreprocessor(config) + + input_ids = torch.ones(3, 8) + input_ids[0, :] = torch.arange(1, 9) # Increasing sequence + input_ids[1, :] = torch.arange(8, 0, -1) # Decreasing sequence + input_ids[2, :] = 1 # Constant sequence + + n_trials = 100 + mask_counts = torch.zeros_like(input_ids) + + for _ in range(n_trials): + outputs = preprocessor(input_ids) + mask_counts += outputs['masked_indices'].float() + + empirical_probs = mask_counts / n_trials + + assert (empirical_probs >= config.epsilon - 0.05).all() # Allow small deviation + assert (empirical_probs <= config.max_mask_prob + 0.05).all() + + +def test_masking_deterministic(): + config = DiffusionMaskingConfig( + enabled=True, + epsilon=0.1, + max_mask_prob=0.5, + pad_prob=0.1, + mask_token_id=103 + ) + + preprocessor = LLaDAMaskingPreprocessor(config) + + + torch.manual_seed(42) + + + input_ids = torch.randint(1, 1000, (2, 6)) + + torch.manual_seed(42) + outputs1 = preprocessor(input_ids) + + torch.manual_seed(42) + outputs2 = preprocessor(input_ids) + + assert torch.equal(outputs1['masked_indices'], outputs2['masked_indices']) + assert torch.equal(outputs1['p_mask'], outputs2['p_mask']) + assert torch.equal(outputs1['input_ids'], outputs2['input_ids']) + + +def 
test_masking_config_validation(): + with pytest.raises(ValueError): + DiffusionMaskingConfig( + enabled=True, + epsilon=-0.1, # Invalid negative value + max_mask_prob=0.5, + pad_prob=0.1, + mask_token_id=103 + ) + + with pytest.raises(ValueError): + DiffusionMaskingConfig( + enabled=True, + epsilon=0.1, + max_mask_prob=1.5, # Invalid value > 1 + pad_prob=0.1, + mask_token_id=103 + ) + + with pytest.raises(ValueError): + DiffusionMaskingConfig( + enabled=True, + epsilon=0.6, # Greater than max_mask_prob + max_mask_prob=0.5, + pad_prob=0.1, + mask_token_id=103 + ) + + with pytest.raises(ValueError): + DiffusionMaskingConfig( + enabled=True, + epsilon=0.1, + max_mask_prob=0.5, + pad_prob=-0.1, # Invalid negative value + mask_token_id=103 + ) \ No newline at end of file diff --git a/tests/test_mlm_loss.py b/tests/test_mlm_loss.py new file mode 100644 index 00000000..f7913c63 --- /dev/null +++ b/tests/test_mlm_loss.py @@ -0,0 +1,149 @@ +import pytest +import torch +import torch.nn.functional as F + +from fast_llm.layers.language_model.head import MLMHead +from fast_llm.layers.language_model.config import LanguageModelBaseConfig +from fast_llm.engine.config_utils.tensor_space import TensorSpace, DefaultDimNames, TensorDim +from fast_llm.layers.transformer.config import TransformerConfig, DiffusionMaskingConfig +from fast_llm.engine.distributed.config import DistributedConfig + + +@pytest.fixture +def mlm_config(): + transformer_config = TransformerConfig( + hidden_size=768, + num_layers=12, + num_attention_heads=12, + diffusion=DiffusionMaskingConfig( + enabled=True, + epsilon=0.1, + max_mask_prob=0.5, + pad_prob=0.1, + mask_token_id=103 + ) + ) + + return LanguageModelBaseConfig( + vocab_size=30522, + transformer=transformer_config, + tie_word_embeddings=False, + parallel_embeddings=False, + prediction_heads=1 + ) + + +@pytest.fixture +def tensor_space(): + distributed_config = DistributedConfig() + tensor_space = TensorSpace(distributed_config) + tensor_space.add_tensor_dim(DefaultDimNames.scalar, 1) + tensor_space.add_tensor_dim("hidden", 768) + tensor_space.add_tensor_dim("vocab", 30522) + return tensor_space + + +def test_mlm_loss_computation(mlm_config, tensor_space): + mlm_head = MLMHead(mlm_config, tensor_space, prediction_distance=0) + + batch_size = 4 + seq_len = 8 + hidden_size = 768 + vocab_size = 30522 + + hidden_states = torch.randn(batch_size, seq_len, hidden_size) + + masked_indices = torch.zeros(batch_size, seq_len, dtype=torch.bool) + masked_indices[:, [2, 5]] = True # Mask positions 2 and 5 in each sequence + + p_mask = torch.full((batch_size, seq_len), 0.15) # 15% masking probability + + labels = torch.randint(0, vocab_size, (batch_size, seq_len)) + + kwargs = { + 'masked_indices': masked_indices, + 'p_mask': p_mask, + 'labels': labels + } + + losses = {} + output = mlm_head(hidden_states, kwargs, losses) + + + assert output is not None + assert isinstance(output, torch.Tensor) + assert output.requires_grad + + + assert losses # losses dictionary should not be empty + + # Test with no masked positions + kwargs['masked_indices'] = torch.zeros_like(masked_indices) + losses = {} + output_no_masks = mlm_head(hidden_states, kwargs, losses) + assert output_no_masks is not None + + # Test with all positions masked + kwargs['masked_indices'] = torch.ones_like(masked_indices) + losses = {} + output_all_masked = mlm_head(hidden_states, kwargs, losses) + assert output_all_masked is not None + + +def test_mlm_loss_edge_cases(mlm_config, tensor_space): + mlm_head = MLMHead(mlm_config, 
tensor_space, prediction_distance=0) + + hidden_states = torch.randn(1, 4, 768) + masked_indices = torch.zeros(1, 4, dtype=torch.bool) + masked_indices[0, 1] = True + p_mask = torch.full((1, 4), 0.15) + labels = torch.randint(0, 30522, (1, 4)) + + kwargs = { + 'masked_indices': masked_indices, + 'p_mask': p_mask, + 'labels': labels + } + + losses = {} + output = mlm_head(hidden_states, kwargs, losses) + assert output is not None + + p_mask = torch.full((1, 4), 0.01) + kwargs['p_mask'] = p_mask + losses = {} + output = mlm_head(hidden_states, kwargs, losses) + assert output is not None + + p_mask = torch.full((1, 4), 0.5) # max_mask_prob from config + kwargs['p_mask'] = p_mask + losses = {} + output = mlm_head(hidden_states, kwargs, losses) + assert output is not None + + +def test_mlm_loss_backward(mlm_config, tensor_space): + mlm_head = MLMHead(mlm_config, tensor_space, prediction_distance=0) + + hidden_states = torch.randn(2, 6, 768, requires_grad=True) + masked_indices = torch.zeros(2, 6, dtype=torch.bool) + masked_indices[:, [1, 4]] = True + p_mask = torch.full((2, 6), 0.15) + labels = torch.randint(0, 30522, (2, 6)) + + kwargs = { + 'masked_indices': masked_indices, + 'p_mask': p_mask, + 'labels': labels + } + + losses = {} + output = mlm_head(hidden_states, kwargs, losses) + + output.backward() + + assert hidden_states.grad is not None + assert not torch.isnan(hidden_states.grad).any() + assert not torch.isinf(hidden_states.grad).any() + + assert hidden_states.grad.shape == hidden_states.shape \ No newline at end of file From 46dd535cc98a09c284b1aa7e89b8aa9a91acd8d5 Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Wed, 4 Jun 2025 18:15:27 +0000 Subject: [PATCH 03/34] temp fixes --- fast_llm/data/dataset/gpt/config.py | 43 ++++++++++----------------- fast_llm/layers/transformer/config.py | 14 ++++----- 2 files changed, 21 insertions(+), 36 deletions(-) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index d84818ce..549561dd 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -67,44 +67,35 @@ class GPTSamplingConfig(SamplingConfig): @config_class() class DiffusionMaskingConfig(Config): """Configuration for diffusion-based masking during data preparation.""" - + enabled: bool = Field( - default=False, - desc="Whether to use diffusion-based masking during training", - hint=FieldHint.feature + default=False, desc="Whether to use diffusion-based masking during training", hint=FieldHint.feature ) - + epsilon: float = Field( - default=1e-3, - desc="Minimum masking probability", - hint=FieldHint.performance, - valid=check_field(Assert.gt, 0) + default=1e-3, desc="Minimum masking probability", hint=FieldHint.performance, valid=check_field(Assert.gt, 0) ) - + max_mask_prob: float = Field( - default=0.15, - desc="Maximum masking probability", - hint=FieldHint.performance, - valid=check_field(Assert.gt, 0) + default=0.15, desc="Maximum masking probability", hint=FieldHint.performance, valid=check_field(Assert.gt, 0) ) pad_prob: float = Field( default=0.01, desc="Probability of padding tokens for 1% of samples", hint=FieldHint.optional, - valid=check_field(Assert.geq, 0) - ) - - mask_token_id: int = Field( - default=103, - desc="Token ID to use for masking", - hint=FieldHint.required + valid=check_field(Assert.geq, 0), ) + mask_token_id: int = Field(default=103, desc="Token ID to use for masking", hint=FieldHint.optional) + def _validate(self) -> None: super()._validate() - Assert.lt(self.epsilon, self.max_mask_prob, 
"epsilon must be less than max_mask_prob") - Assert.lt(self.max_mask_prob, 1.0, "max_mask_prob must be less than 1.0") + Assert.lt(self.epsilon, self.max_mask_prob) # , "epsilon must be less than max_mask_prob") + Assert.lt( + self.max_mask_prob, + 1.0, + ) # "max_mask_prob must be less than 1.0") if self.enabled: Assert.is_not_none(self.mask_token_id, "mask_token_id must be set when masking is enabled") @@ -122,11 +113,9 @@ class GPTSamplingParameters(SamplingParameters): # How many extra tokens to add to the sequence length. # This is used to provide labels even for the last tokens in the sequence. extra_tokens: int = 1 - + # Diffusion masking configuration - diffusion: DiffusionMaskingConfig = dataclasses.field( - default_factory=DiffusionMaskingConfig - ) + diffusion: DiffusionMaskingConfig = dataclasses.field(default_factory=DiffusionMaskingConfig) @dataclasses.dataclass(kw_only=True) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 17144132..3fe33113 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -470,20 +470,16 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: @config_class() -class DiffusionMaskingConfig(Config): +class DiffusionMaskingConfig(TransformerArchitectureConfig): """Configuration for diffusion-based masking in the transformer model. This config only contains model-specific parameters. For masking parameters, refer to fast_llm.data.dataset.gpt.config.DiffusionMaskingConfig.""" - + enabled: bool = Field( - default=False, - desc="Whether to use diffusion-based masking during training", - hint=FieldHint.feature + default=False, desc="Whether to use diffusion-based masking during training", hint=FieldHint.feature ) bidirectional_attention: bool = Field( - default=True, - desc="Whether to use bidirectional attention for masked tokens", - hint=FieldHint.feature + default=True, desc="Whether to use bidirectional attention for masked tokens", hint=FieldHint.feature ) def _validate(self) -> None: @@ -498,7 +494,7 @@ class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): diffusion: DiffusionMaskingConfig = Field( default_factory=DiffusionMaskingConfig, desc="Configuration for diffusion-based masking", - hint=FieldHint.feature + hint=FieldHint.feature, ) # Default: hidden_size**-0.5 # TODO: Allow custom initialization (InitializationConfig?) From aa8ab4ddef2f160b2f53d8648af9ff286035325e Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Wed, 4 Jun 2025 21:29:09 +0000 Subject: [PATCH 04/34] tmp fix --- fast_llm/models/gpt/model.py | 39 ++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 0f1453d4..24d5af4b 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -85,14 +85,18 @@ def get_output_layers(self) -> list[Layer]: # The previous layers return a stack of shared_hidden and transformer_output. 
return_input=i < self._config.prediction_heads - 1, ), - MLMHead( - self._config, - self._tensor_space, - prediction_distance=i, - ) if self._config.transformer.diffusion.enabled else LanguageModelHead( - self._config, - self._tensor_space, - prediction_distance=i, + ( + MLMHead( + self._config, + self._tensor_space, + prediction_distance=i, + ) + if self._config.transformer.bidirectional_attention + else LanguageModelHead( + self._config, + self._tensor_space, + prediction_distance=i, + ) ), ] ] @@ -333,18 +337,13 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: else: return {} - @property + # @property def loss_defs(self) -> list[LossDef]: loss_defs = [] - if self._config.transformer.diffusion.enabled: + if self._config.transformer.bidirectional_attention: # MLM loss for LLaDA-style training loss_defs.append( - LossDef( - name=LanguageModelLossNames.mlm_loss, - formatted_name="MLM Loss", - count=1, - dtype=torch.float32 - ) + LossDef(name=LanguageModelLossNames.mlm_loss, formatted_name="MLM Loss", count=1, dtype=torch.float32) ) else: # Standard language modeling loss @@ -353,10 +352,10 @@ def loss_defs(self) -> list[LossDef]: name=LanguageModelLossNames.language_model_loss, formatted_name="Language Model Loss", count=1, - dtype=torch.float32 + dtype=torch.float32, ) ) - + if self._config.transformer.num_experts > 1: if self._config.transformer.expert_routing_type == RoutingType.topk: loss_defs.append( @@ -364,7 +363,7 @@ def loss_defs(self) -> list[LossDef]: name=TransformerLossNames.load_balancing_loss, formatted_name="Load Balancing Loss", count=1, - dtype=torch.float32 + dtype=torch.float32, ) ) if self._config.transformer.expert_z_loss_coefficient > 0: @@ -373,7 +372,7 @@ def loss_defs(self) -> list[LossDef]: name=TransformerLossNames.router_z_loss, formatted_name="Router Z Loss", count=1, - dtype=torch.float32 + dtype=torch.float32, ) ) return loss_defs From 9f348e773fcd794ff0b3cf4b11ce5d8a1bf80123 Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Sat, 7 Jun 2025 01:51:17 +0000 Subject: [PATCH 05/34] including masked diffusion training setup --- fast_llm/data/data/gpt/data.py | 77 ++++++----- fast_llm/data/dataset/gpt/config.py | 61 +++++---- fast_llm/engine/base_model/base_model.py | 4 +- fast_llm/engine/schedule/runner.py | 3 +- fast_llm/engine/training/trainer.py | 2 +- fast_llm/layers/language_model/config.py | 3 + fast_llm/layers/language_model/head.py | 112 ++++++++++++++-- fast_llm/layers/transformer/config.py | 91 ++++++++++--- fast_llm/layers/transformer/preprocessing.py | 129 +++++++++++-------- fast_llm/models/custom/model.py | 4 +- fast_llm/models/gpt/model.py | 36 +++++- fast_llm/models/gpt/trainer.py | 1 + 12 files changed, 373 insertions(+), 150 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index dfcb0e1e..72218ea4 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -32,51 +32,68 @@ class GPTBatch: token_ids: torch.Tensor loss_masking_spans: list[torch.Tensor] | None = None sequence_lengths: list[torch.Tensor] | None = None + mask_indexes: torch.Tensor | None = None + mask_probabilities: torch.Tensor | None = None def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch: - """Collate function that supports LLaDA-style masking.""" + stacked_ids = np.stack([sample.token_ids for sample in batch]) stacked_spans = None sequence_lengths = None - + mask_indexes = None + mask_probabilities = None + token_ids = 
torch.from_numpy(stacked_ids) - + if sampling_parameters.diffusion.enabled: + + diffusion_config = sampling_parameters.diffusion + batch_size, seq_len = token_ids.shape - device = token_ids.device - t = torch.rand(batch_size, device=device) - p_mask = (1 - sampling_parameters.diffusion.epsilon) * t + sampling_parameters.diffusion.epsilon - p_mask = torch.min(p_mask, torch.tensor(sampling_parameters.diffusion.max_mask_prob)) - p_mask = p_mask[:, None].expand(-1, seq_len) - - masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask - - if sampling_parameters.diffusion.pad_prob > 0: - pad_mask = torch.rand((batch_size,), device=device) < sampling_parameters.diffusion.pad_prob - if pad_mask.any(): - masked_indices[pad_mask] = True - - token_ids = torch.where(masked_indices, sampling_parameters.diffusion.mask_token_id, token_ids) - - if not stacked_spans: - stacked_spans = [] - stacked_spans.extend([ - torch.stack([masked_indices[i], p_mask[i]]) - for i in range(batch_size) - ]) - + mask_token_id = diffusion_config.mask_token_id + + # Generate a random tensor of batch size to seed masking probabilities + t = torch.rand((batch_size,)) + + # Compute the mask probabilities for every sequence in the batch + p_mask = (1 - diffusion_config.epsilon) * t + diffusion_config.epsilon + + # Do we need to clamp at max_mask_prob? + # p_mask = torch.min(p_mask, torch.tensor(diffusion_config.max_mask_prob)) + + # Repeat the same mask probability for each token in the sequence + mask_probabilities = p_mask[:, None].repeat(1, seq_len) + # Input has an additional token for shitting, is [0, 1, 2, 3, 4] -> [1, 2, 3, 4] + # We should not mask it + mask_probabilities[:, 0] = 0.0 + # print(f"2 p_mask: {mask_probabilities} {mask_probabilities.shape}") + + # Generate random values for all tokens in the batch and only mask the positions\ + # where the value is smaller than the mask probability + mask_indexes = torch.rand((batch_size, seq_len)) < mask_probabilities + + # Need further classification of this padding - 1% data to have partial sequences and padding + # if diffusion_config.pad_prob > 0: + # pad_mask = torch.rand((batch_size,), device=device) < diffusion_config.pad_prob + # if pad_mask.any(): + # mask_indexes[pad_mask] = True + + # Replace masked tokens with the mask token ID to create input for the model. + token_ids = torch.where(mask_indexes, mask_token_id, token_ids) + if sampling_parameters.use_loss_masking_spans: - if not stacked_spans: - stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] - + stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] + if not sampling_parameters.cross_document_attention: sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch] - + return GPTBatch( token_ids=token_ids, loss_masking_spans=stacked_spans, - sequence_lengths=sequence_lengths + sequence_lengths=sequence_lengths, + mask_indexes=mask_indexes, + mask_probabilities=mask_probabilities, ) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 549561dd..c5979194 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -44,32 +44,12 @@ class ShufflingType(str, enum.Enum): legacy = "legacy" -@config_class() -class GPTSamplingConfig(SamplingConfig): - """ - A dataset-dependent configuration for sampling. - """ - - gpu: bool = Field( - default=True, - desc="Enable fast sampling on GPU." 
- " Note that random sampling works differently on GPU," - " so the sample won't match the CPU equivalent.", - hint=FieldHint.feature, - ) - shuffle: ShufflingType = Field( - default=ShufflingType.epoch, - desc="Shuffling strategy.", - hint=FieldHint.feature, - ) - - @config_class() class DiffusionMaskingConfig(Config): """Configuration for diffusion-based masking during data preparation.""" enabled: bool = Field( - default=False, desc="Whether to use diffusion-based masking during training", hint=FieldHint.feature + default=False, desc="Whether to use masked diffusion during training", hint=FieldHint.feature ) epsilon: float = Field( @@ -95,9 +75,34 @@ def _validate(self) -> None: Assert.lt( self.max_mask_prob, 1.0, - ) # "max_mask_prob must be less than 1.0") - if self.enabled: - Assert.is_not_none(self.mask_token_id, "mask_token_id must be set when masking is enabled") + ) # "max_mask_prob must be less than 1.0") + # if self.enabled: + # Assert.is_not_none(self.mask_token_id, "mask_token_id must be set when masking is enabled") + + +@config_class() +class GPTSamplingConfig(SamplingConfig): + """ + A dataset-dependent configuration for sampling. + """ + + gpu: bool = Field( + default=True, + desc="Enable fast sampling on GPU." + " Note that random sampling works differently on GPU," + " so the sample won't match the CPU equivalent.", + hint=FieldHint.feature, + ) + shuffle: ShufflingType = Field( + default=ShufflingType.epoch, + desc="Shuffling strategy.", + hint=FieldHint.feature, + ) + diffusion: DiffusionMaskingConfig = Field( + default_factory=DiffusionMaskingConfig, + desc="Configuration for diffusion-based masking during data preparation.", + hint=FieldHint.feature, + ) @dataclasses.dataclass(kw_only=True) @@ -113,9 +118,11 @@ class GPTSamplingParameters(SamplingParameters): # How many extra tokens to add to the sequence length. # This is used to provide labels even for the last tokens in the sequence. extra_tokens: int = 1 - - # Diffusion masking configuration - diffusion: DiffusionMaskingConfig = dataclasses.field(default_factory=DiffusionMaskingConfig) + diffusion: DiffusionMaskingConfig = Field( + default_factory=DiffusionMaskingConfig, + desc="Configuration for diffusion-based masking during data preparation. Will be copied from GPTSamplingConfig during ", + hint=FieldHint.feature, + ) @dataclasses.dataclass(kw_only=True) diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 76da0f9b..d7de7cd3 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -131,9 +131,9 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: # The name (dict key) is used to insert the weight in the kwargs of the forward pass. 
return {} - @property + # @property @abc.abstractmethod - def loss_defs(self) -> list[LossDef]: + def get_loss_defs(self) -> list[LossDef]: pass def add_preprocessor(self, preprocessor: Preprocessor): diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 8eca4559..f5f0e111 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -93,7 +93,8 @@ def __init__( self._stages: list[Stage] = self._multi_stage.stages self._tied_parameters = self._multi_stage.tied_parameters self._num_stages = len(self._stages) - self._loss_defs = {loss_def.name: loss_def for loss_def in self._multi_stage.base_model.loss_defs} + print(f"base_model={type(self._multi_stage.base_model)} {self._multi_stage.base_model.get_loss_defs()} ") + self._loss_defs = {loss_def.name: loss_def for loss_def in self._multi_stage.base_model.get_loss_defs()} def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> None: assert not self._is_setup diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 66f1ad86..5685729d 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -88,7 +88,7 @@ def __init__(self, config: TrainerConfig): # Prune empty phases. self._samples_per_split = {k: v for k, v in self._samples_per_split.items() if len(v) > 0} - self._loss_defs = self._multi_stage.base_model.loss_defs + self._loss_defs = self._multi_stage.base_model.get_loss_defs() # Setup the schedules self._schedule = { diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 387fa7ad..19010ba4 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -21,6 +21,7 @@ class LanguageModelDimNames: class LanguageModelLossNames: language_model_loss = "language_model_loss" z_loss = "z_loss" + mlm_loss = "masked_language_model_loss" @staticmethod def multi_token_prediction_loss(index: int) -> str: @@ -34,6 +35,8 @@ class LanguageModelKwargs: # TODO: These are generic labels = "labels" phase = "phase" + mask_indexes = "mask_indexes" + mask_probabilities = "mask_probabilities" @config_class() diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index b84cc8f9..0fddd9d8 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -153,6 +153,7 @@ def _forward_backward( ] ) labels = labels.flatten() + if self._sequence_parallel_logits: labels = split_op(labels, self._tensor_space.distributed.tensor_group, 0) do_grad = labels is not None and self.training @@ -299,11 +300,62 @@ def _logits_cross_entropy_forward_backward( del logits return loss, output_parallel_linear_backward(grad, context) + class MLMHead(LanguageModelHead): """ - A masked language model head for diffusion-based training. 
+ A masked language model head for diffusion-based training.` """ - + + def __init__( + self, + config: LanguageModelBaseConfig, + tensor_space: TensorSpace, + prediction_distance: int, + ): + super().__init__(config, tensor_space, prediction_distance) + self._loss_name = LanguageModelLossNames.mlm_loss + + # def forward( + # self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None + # ) -> torch.Tensor: + # if isinstance(input_, TensorMeta): + # return TensorMeta.from_tensor_space( + # (DefaultDimNames.scalar,), + # self._tensor_space, + # tensor_name="Loss", + # reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa + # ) + + # # Dropping MTP Stuff + # # if not self.is_last_head: + # # # MTP: split the stacked input + # # shared_hidden, input_ = torch.unbind(input_, dim=0) + + # # TODO: Pytorch copies the grads in backward for no reason (not sure if still the case) + # # TODO: Torch compile implementation sometimes break. + # # TODO: Double-check correctness, optimize a bit more. + # # TODO: Drop autograd entirely. + # # TODO: Skip cross-entropy backward if not needed. + + # print(f"forward input_: {input_.shape} {input_}") + + # # Input needs to be the masked input with masked tokens + # input_ = kwargs["noisy_batch"] + + # language_model_loss = self._forward(input_, kwargs, losses) + # if language_model_loss is not None: + # losses[self._loss_name].append(language_model_loss) + # # TODO: Return the model output when needed. + # # if self.is_last_head: + # # # Last head should return the loss for backward. + # return language_model_loss + # # else: + # # if self.training: + # # # Backward hook to compute the gradient of the loss + # # shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, 1.0) + # # # MTP: Return shared_hidden to be used by the next head. + # # return shared_hidden + def _logits_cross_entropy_forward_backward( self, input_: torch.Tensor, @@ -313,6 +365,16 @@ def _logits_cross_entropy_forward_backward( kwargs: dict, losses: dict | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: + + # can we just do this here? + # Input needs to be the masked input with masked tokens + # instead of forward pass + # this seems only a liner layer we need other layers too + # input_ = kwargs["noisy_batch"] + + # print(f"input_: {input_.shape} {input_}") + # print(f"labels: {labels.shape} {labels}") + logits, context = output_parallel_linear_forward( input_=input_, weight=weight, @@ -321,6 +383,8 @@ def _logits_cross_entropy_forward_backward( sequence_parallel=self._sequence_parallel and self._parallel_embeddings, ) + # print(f"logits {logits.shape} {logits}") + if self._z_loss_factor > 0.0: logits = z_loss( logits, @@ -335,24 +399,48 @@ def _logits_cross_entropy_forward_backward( if labels is None: return logits * self._logits_scale_factor, None - masked_indices = kwargs['masked_indices'] - p_mask = kwargs['p_mask'] - - masked_logits = logits[masked_indices] - masked_labels = labels[masked_indices] - masked_p = p_mask[masked_indices] - + masked_indices = kwargs[LanguageModelKwargs.mask_indexes] + p_mask = kwargs[LanguageModelKwargs.mask_probabilities] + # print(f"masked_indices: {masked_indices.shape} {masked_indices}") + + # The labels are already left shifted x = [0, 1, 2, 3, 4, 5] -> [1, 2, 3, 4, 5?] + # then the mask index on the label will give the correct tokens one-hot vector? 
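+        # For reference, in the LLaDA / SMDM formulation the target for a masked position is the
+        # original token at that same position (no next-token shift), and the loss is computed only
+        # over masked positions, so the model cannot simply copy a visible neighbouring token.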
+ # this happens in model.py part of preprocessing + + # Question pier: if 2 is the masked token, settling needs to settled 3; but 3 is already seen by the model, + # can it just learn to copy 3? i.e copy the next token to the masked? + + # TODO: update loss only on masked tokens error from earlier linear layer missing tokens becuz of masking?? + + masked_logits = ( + logits # logits[masked_indices[:, 1:]] # Skip the first token, which is the shift/context token + ) + # flatten the masked indices to match the logits + masked_indices_flt = masked_indices[:, 1:].flatten() + # print(f"masked_indices: {masked_indices_flt.shape} {masked_indices_flt}") + masked_labels = labels # labels[masked_indices_flt] + # print(f"p_mask {p_mask.shape} {masked_indices.shape}") + p_mask[masked_indices] + # Compute MLM loss loss, grad = cross_entropy_forward_backward( masked_logits.flatten(0, -2), masked_labels, group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, - grad_output=grad_output / masked_p, + grad_output=grad_output, implementation=self._cross_entropy_impl, logits_scale_factor=self._logits_scale_factor, ) - loss = loss / (labels.shape[0] * labels.shape[1]) - + # print(f"loss: {loss.shape} {loss}") + # Revisit this with the formula and what happens inside the cross_entropy_forward_backward + # MDM https://github.com/ML-GSAI/SMDM/blob/583aa4716d17728dbb825aec6c24a121164d616a/pretrain/train_mdm.py#L274 + # loss = loss / masked_p + # print(f"loss: {loss.shape} {loss}") + + # revisit this with the formula and what happens inside the cross_entropy_forward_backward + # MDM https://github.com/ML-GSAI/SMDM/blob/583aa4716d17728dbb825aec6c24a121164d616a/pretrain/train_mdm.py#L275 + # loss = loss.sum() / (labels.shape[0] * labels.shape[1]) + del logits return loss, output_parallel_linear_backward(grad, context) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 3fe33113..58528376 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -469,21 +469,66 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: ) -@config_class() -class DiffusionMaskingConfig(TransformerArchitectureConfig): - """Configuration for diffusion-based masking in the transformer model. - This config only contains model-specific parameters. For masking parameters, - refer to fast_llm.data.dataset.gpt.config.DiffusionMaskingConfig.""" - - enabled: bool = Field( - default=False, desc="Whether to use diffusion-based masking during training", hint=FieldHint.feature - ) - bidirectional_attention: bool = Field( - default=True, desc="Whether to use bidirectional attention for masked tokens", hint=FieldHint.feature - ) - - def _validate(self) -> None: - super()._validate() +# @config_class() +# class DiffusionMaskingConfig(TransformerArchitectureConfig): +# """Configuration for diffusion-based masking in the transformer model. +# This config only contains model-specific parameters. 
For masking parameters, +# refer to fast_llm.data.dataset.gpt.config.DiffusionMaskingConfig.""" + +# enabled: bool = Field( +# default=False, desc="Whether to use diffusion-based masking during training", hint=FieldHint.feature +# ) +# bidirectional_attention: bool = Field( +# default=True, desc="Whether to use bidirectional attention for masked tokens", hint=FieldHint.feature +# ) + +# def _validate(self) -> None: +# super()._validate() + + +# @config_class() +# class DiffusionMaskingConfig(TransformerArchitectureConfig): +# """Configuration for diffusion-based masking during data preparation.""" + +# enabled: bool = Field( +# default=False, +# desc="Whether to use masked diffusion during training", +# hint=FieldHint.feature +# ) + +# epsilon: float = Field( +# default=1e-3, +# desc="Minimum masking probability", +# hint=FieldHint.performance, +# valid=check_field(Assert.gt, 0) +# ) + +# max_mask_prob: float = Field( +# default=0.15, +# desc="Maximum masking probability", +# hint=FieldHint.performance, +# valid=check_field(Assert.gt, 0) +# ) + +# pad_prob: float = Field( +# default=0.01, +# desc="Probability of padding tokens for 1% of samples", +# hint=FieldHint.optional, +# valid=check_field(Assert.geq, 0) +# ) + +# mask_token_id: int = Field( +# default=103, +# desc="Token ID to use for masking", +# hint=FieldHint.optional +# ) + +# def _validate(self) -> None: +# super()._validate() +# Assert.lt(self.epsilon, self.max_mask_prob) #, "epsilon must be less than max_mask_prob") +# Assert.lt(self.max_mask_prob, 1.0,) # "max_mask_prob must be less than 1.0") +# # if self.enabled: +# # Assert.is_not_none(self.mask_token_id, "mask_token_id must be set when masking is enabled") @config_class() @@ -491,11 +536,7 @@ class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): normalization: NormalizationConfig = FieldUpdate(default_factory=NormalizationConfig) rotary: RotaryConfig = FieldUpdate(default_factory=RotaryConfig) peft: TransformerPeftConfig = FieldUpdate(default_factory=TransformerPeftConfig) - diffusion: DiffusionMaskingConfig = Field( - default_factory=DiffusionMaskingConfig, - desc="Configuration for diffusion-based masking", - hint=FieldHint.feature, - ) + # Default: hidden_size**-0.5 # TODO: Allow custom initialization (InitializationConfig?) init_method_std: float = Field( @@ -688,6 +729,16 @@ class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.", hint=FieldHint.expert, ) + diffusion: bool = Field( + default=False, + desc="Use masked-diffusion for training.", + hint=FieldHint.feature, + ) + # diffusion: DiffusionMaskingConfig = Field( + # default_factory=DiffusionMaskingConfig, + # desc="Configuration for masked diffusion training.", + # hint=FieldHint.feature, + # ) def _validate(self) -> None: with self._set_implicit_default(): diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index ba66cf3a..0279abaf 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -247,6 +247,8 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: kwargs[TransformerKwargs.attention_mask] & document_mask[:, None, sequence_k - sequence_q : sequence_k, None, :sequence_k] ) + + # can we add a bidirectional attention here? 
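+        # A bidirectional variant would replace the causal `self._mask` slice with an all-True mask of
+        # the same shape (the LLaDAMaskingPreprocessor below builds exactly that via
+        # `torch.ones((batch_size, 1, seq_len, seq_len), dtype=torch.bool)`), so every token can attend
+        # to every other token while the document mask above still isolates documents.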
kwargs[TransformerKwargs.attention_mask_value] = self._mask_value def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: @@ -342,71 +344,96 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: class LLaDAMaskingPreprocessor(Preprocessor): """Preprocessor for LLaDA-style masking with diffusion-based training.""" - + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = tensor_space.distributed_config + # print(f"tensor_space: {tensor_space._tensor_dims.keys()}") self._scalar_dim = tensor_space.get_tensor_dim(DefaultDimNames.scalar) - self._sequence_dim = tensor_space.get_tensor_dim(TransformerDimNames.sequence_q) - + # self._sequence_dim = tensor_space.get_tensor_dim(TransformerDimNames.sequence_q) + def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: """Apply LLaDA-style masking to the input sequence.""" # Get diffusion config from dataset parameters - diffusion_config = kwargs['parameters'].diffusion + # print(f"kwargs: {kwargs.keys()}") + + print(f"1 batch: {type(batch)} {batch.shape}") + + diffusion_config = self._config.diffusion if not diffusion_config.enabled: return - + batch_size, seq_len = batch.shape device = batch.device - - t = torch.rand(batch_size, device=device) - + mask_token_id = diffusion_config.mask_token_id + + # Generate a random tensor of batch size to seed masking probabilities + t = torch.rand((batch_size,), device=device) + + # Compute the mask probabilities for every sequence in the batch p_mask = (1 - diffusion_config.epsilon) * t + diffusion_config.epsilon - p_mask = torch.min(p_mask, torch.tensor(diffusion_config.max_mask_prob)) - p_mask = p_mask[:, None].expand(-1, seq_len) - + + # Do we need to clamp at max_mask_prob? + # p_mask = torch.min(p_mask, torch.tensor(diffusion_config.max_mask_prob)) + + # Repeat the same mask probability for each token in the sequence + p_mask = p_mask[:, None].repeat(1, seq_len) + print(f"2 p_mask: {p_mask} {p_mask.shape}") + + # Generate random values for all tokens in the batch and only mask the positions\ + # where the value is smaller than the mask probability masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask - - if diffusion_config.pad_prob > 0: - pad_mask = torch.rand((batch_size,), device=device) < diffusion_config.pad_prob - if pad_mask.any(): - masked_indices[pad_mask] = True - - kwargs['masked_indices'] = masked_indices - kwargs['p_mask'] = p_mask - - if self._config.diffusion.bidirectional_attention: - # Bidirectional attention - all tokens can attend to all other tokens - attention_mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool) - else: - # Causal attention - attention_mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool).tril_() - + # Need further classification of this padding - 1% data to have partial sequences and padding + # if diffusion_config.pad_prob > 0: + # pad_mask = torch.rand((batch_size,), device=device) < diffusion_config.pad_prob + # if pad_mask.any(): + # masked_indices[pad_mask] = True + + # Replace masked tokens with the mask token ID to create input for the model. 
+ noisy_batch = torch.where(masked_indices, mask_token_id, batch) + + kwargs["masked_indices"] = masked_indices + kwargs["p_mask"] = p_mask + kwargs["noisy_batch"] = noisy_batch + + # Bidirectional attention - all tokens can attend to all other tokens + attention_mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool) + + # if self._config.bidirectional_attention: + # # Bidirectional attention - all tokens can attend to all other tokens + # attention_mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool) + # else: + # # Causal attention + # attention_mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool).tril_() + kwargs[TransformerKwargs.attention_mask] = attention_mask kwargs[TransformerKwargs.attention_mask_value] = torch.tensor(-10000.0, device=device) - - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - """Define tensor metadata for masking tensors.""" - # Get diffusion config from dataset parameters - diffusion_config = kwargs['parameters'].diffusion - if not diffusion_config.enabled: - return - - kwargs['masked_indices'] = TensorMeta.from_dims( - (self._scalar_dim, self._sequence_dim), - tensor_name='masked_indices' - ) - kwargs['p_mask'] = TensorMeta.from_dims( - (self._scalar_dim, self._sequence_dim), - tensor_name='p_mask' - ) - kwargs[TransformerKwargs.attention_mask] = TensorMeta.from_dims( - (self._scalar_dim, self._scalar_dim, self._sequence_dim, self._sequence_dim), - tensor_name=TransformerKwargs.attention_mask - ) - kwargs[TransformerKwargs.attention_mask_value] = TensorMeta.from_dims( - (self._scalar_dim,), - tensor_name=TransformerKwargs.attention_mask_value - ) + + # def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + # """Define tensor metadata for masking tensors.""" + + # print(f"kwargs: {kwargs.keys()}") + # # Get diffusion config from dataset parameters + # sequence_q_dim = kwargs[TransformerKwargs.sequence_q_dim].size + # diffusion_config = kwargs['parameters'].diffusion + # if not diffusion_config.enabled: + # return + + # kwargs['masked_indices'] = TensorMeta.from_dims( + # (self._scalar_dim, sequence_q_dim), + # tensor_name='masked_indices' + # ) + # kwargs['p_mask'] = TensorMeta.from_dims( + # (self._scalar_dim, sequence_q_dim), + # tensor_name='p_mask' + # ) + # kwargs[TransformerKwargs.attention_mask] = TensorMeta.from_dims( + # (self._scalar_dim, self._scalar_dim, sequence_q_dim, self._sequence_dim), + # tensor_name=TransformerKwargs.attention_mask + # ) + # kwargs[TransformerKwargs.attention_mask_value] = TensorMeta.from_dims( + # (self._scalar_dim,), + # tensor_name=TransformerKwargs.attention_mask_value + # ) diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index c206ef40..f05b5a89 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -59,8 +59,8 @@ def preprocess( # TODO: Adjust or reimplement. return super().preprocess(batch, preprocessed_meta, phase=phase, iteration=iteration, metrics=metrics) - @property - def loss_defs(self) -> list[LossDef]: + # @property + def get_loss_defs(self) -> list[LossDef]: # TODO: Adjust or reimplement. 
return super().loss_defs diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 24d5af4b..b0c9514f 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -71,6 +71,9 @@ def __init__( else: self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) + # if self._config.transformer.diffusion.enabled: + # self._preprocessors.append(LLaDAMaskingPreprocessor(self._config.transformer, self._tensor_space)) + def get_output_layers(self) -> list[Layer]: return [ layer @@ -91,7 +94,7 @@ def get_output_layers(self) -> list[Layer]: self._tensor_space, prediction_distance=i, ) - if self._config.transformer.bidirectional_attention + if self._config.transformer.diffusion else LanguageModelHead( self._config, self._tensor_space, @@ -297,7 +300,32 @@ def preprocess( else: labels[i, start : end + 1] = -100 kwargs[LanguageModelKwargs.labels] = labels + + if batch.mask_indexes is not None: + # We are in masked-diffusion mode, so we need to add the mask indexes and probabilities to kwargs + # print(f'in masked-diffusion mode, batch.mask_indexes: {batch.mask_indexes}') + kwargs[LanguageModelKwargs.mask_indexes] = batch.mask_indexes.to( + device=self._tensor_space.distributed.device + ) + kwargs[LanguageModelKwargs.mask_probabilities] = batch.mask_probabilities.to( + device=self._tensor_space.distributed.device + ) + # Setup bidirection attention for diffusion should we set this in a preprocessor? BackupAttentionPreprocessor? + batch_size, seq_len = batch.token_ids.shape + attention_mask = torch.ones( + (batch_size, 1, seq_len, seq_len), + dtype=torch.bool, + device=self._tensor_space.distributed.device, + ) + kwargs[TransformerKwargs.attention_mask] = attention_mask + kwargs[TransformerKwargs.attention_mask_value] = torch.tensor( + -10000.0, device=self._tensor_space.distributed.device + ) + + # print(f"batch.token_ids aka inputs: {batch.token_ids}") + # print(f"labels: {labels}") for preprocessor in self._preprocessors: + # Update this include p_maks and mask index in kwargs preprocessor.preprocess(tokens, kwargs) preprocessed.append((tokens, kwargs)) @@ -338,10 +366,10 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: return {} # @property - def loss_defs(self) -> list[LossDef]: + def get_loss_defs(self) -> list[LossDef]: loss_defs = [] - if self._config.transformer.bidirectional_attention: - # MLM loss for LLaDA-style training + if self._config.transformer.diffusion: + # Masked LM Loss for masked-diffusion training loss_defs.append( LossDef(name=LanguageModelLossNames.mlm_loss, formatted_name="MLM Loss", count=1, dtype=torch.float32) ) diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index a1c0c8bb..91e621e2 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -50,6 +50,7 @@ def _get_sampling_parameters( "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, "cross_document_attention": self._config.batch.cross_document_attention, "extra_tokens": self._config.model.base_model.prediction_heads, + "diffusion": self._config.data.sampling.diffusion, } ) return parameters if _return_dict else GPTSamplingParameters(**parameters) From cdc9c96d739d44f229af2538012200c2851602ad Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Wed, 11 Jun 2025 01:37:27 +0000 Subject: [PATCH 06/34] adding weighted loss --- fast_llm/data/data/gpt/data.py | 12 ++- fast_llm/functional/cross_entropy.py | 8 +- fast_llm/layers/language_model/head.py | 
75 ++++++++------- fast_llm/layers/transformer/config.py | 67 -------------- fast_llm/layers/transformer/preprocessing.py | 97 -------------------- fast_llm/models/gpt/model.py | 4 +- 6 files changed, 61 insertions(+), 202 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 72218ea4..614913b2 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -65,8 +65,15 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling # Repeat the same mask probability for each token in the sequence mask_probabilities = p_mask[:, None].repeat(1, seq_len) # Input has an additional token for shitting, is [0, 1, 2, 3, 4] -> [1, 2, 3, 4] + + # index [0, 1, 2, 3, 4, 5] -> + # The labels are already left shifted x = [A, B, C, D, E, F] -> + # embd = [A, B, C, D, E] + # label = [B, C, D, E, F] + # Last input token is dropped from the processing + # We should not mask it - mask_probabilities[:, 0] = 0.0 + # mask_probabilities[:, 0] = 0.0 # print(f"2 p_mask: {mask_probabilities} {mask_probabilities.shape}") # Generate random values for all tokens in the batch and only mask the positions\ @@ -82,6 +89,9 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling # Replace masked tokens with the mask token ID to create input for the model. token_ids = torch.where(mask_indexes, mask_token_id, token_ids) + mask_indexes = mask_indexes[:, :-1] # Remove the last token, which is not used for prediction. + mask_probabilities = mask_probabilities[:, :-1] # Remove the last token, which is not used for prediction. + if sampling_parameters.use_loss_masking_spans: stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index e87581f1..04765a1f 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -40,6 +40,7 @@ def fused_cross_entropy_forward_backward( target: torch.Tensor, grad_output: float | None, logits_scale_factor: float = 1.0, + avg_loss: bool = True, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A fused implementation of cross-entropy with torch compile. @@ -71,7 +72,8 @@ def fused_cross_entropy_forward_backward( per_sample_loss = sum_exp_logits.log().sub(logits_norm.gather(1, target).squeeze(1)) * loss_mask - return per_sample_loss.mean(), grad + # print(f"per_sample_loss: {per_sample_loss} {per_sample_loss.shape}") + return per_sample_loss.mean() if avg_loss else per_sample_loss, grad @torch.compile @@ -140,6 +142,7 @@ def cross_entropy_forward_backward( group: ProcessGroup | None, implementation: CrossEntropyImpl = CrossEntropyImpl.fused, logits_scale_factor: float = 1.0, + avg_loss: bool = True, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Select the appropriate implementation of cross-entropy. @@ -147,6 +150,7 @@ def cross_entropy_forward_backward( It doesn't have a tensor-parallel implementation, but can be computed in a sequence-tensor-parallel way, which is faster and has a relatively small memory overhead. 
""" + # print(f"CrossEntropyImpl: {implementation} {group}") if group: Assert.eq(implementation, CrossEntropyImpl.fused) return parallel_cross_entropy_forward_backward( @@ -154,5 +158,5 @@ def cross_entropy_forward_backward( ) else: return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation]( - logits, target, grad_output, logits_scale_factor=logits_scale_factor + logits, target, grad_output, logits_scale_factor=logits_scale_factor, avg_loss=avg_loss ) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 0fddd9d8..7d64ff4f 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -366,15 +366,6 @@ def _logits_cross_entropy_forward_backward( losses: dict | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: - # can we just do this here? - # Input needs to be the masked input with masked tokens - # instead of forward pass - # this seems only a liner layer we need other layers too - # input_ = kwargs["noisy_batch"] - - # print(f"input_: {input_.shape} {input_}") - # print(f"labels: {labels.shape} {labels}") - logits, context = output_parallel_linear_forward( input_=input_, weight=weight, @@ -400,47 +391,65 @@ def _logits_cross_entropy_forward_backward( return logits * self._logits_scale_factor, None masked_indices = kwargs[LanguageModelKwargs.mask_indexes] - p_mask = kwargs[LanguageModelKwargs.mask_probabilities] - # print(f"masked_indices: {masked_indices.shape} {masked_indices}") + mask_probabilities = kwargs[LanguageModelKwargs.mask_probabilities] + # index [0, 1, 2, 3, 4, 5] -> + # The labels are already left shifted x = [A, B, C, D, E, F] -> + # embd = [A, B, C, D, E] + # label = [B, C, D, E, F] - # The labels are already left shifted x = [0, 1, 2, 3, 4, 5] -> [1, 2, 3, 4, 5?] - # then the mask index on the label will give the correct tokens one-hot vector? - # this happens in model.py part of preprocessing - - # Question pier: if 2 is the masked token, settling needs to settled 3; but 3 is already seen by the model, + # Question Pier: if 2 is the masked token, settling needs to settled 3; but 3 is already seen by the model, # can it just learn to copy 3? i.e copy the next token to the masked? + # Yes. it will we need to include curruption to properly handle this. but it seems not to big looking at other CPT (diffuLlama) + + # print(f"context: {context[0].shape} {context}") + # print(f"logits {logits.shape} {logits}") + # print(f"labels: {labels.shape} {labels}") + # print(f"masked_indices: {masked_indices.shape} {masked_indices}") - # TODO: update loss only on masked tokens error from earlier linear layer missing tokens becuz of masking?? + # Compute CrossEntropy loss and weight each loss differently + # We use grad from all the input positions for backward pass. + # Find a way to weight the individual losses from each position seperatly, leave the grads alone. + # only get grads fron the masked positions ??? - masked_logits = ( - logits # logits[masked_indices[:, 1:]] # Skip the first token, which is the shift/context token - ) - # flatten the masked indices to match the logits - masked_indices_flt = masked_indices[:, 1:].flatten() - # print(f"masked_indices: {masked_indices_flt.shape} {masked_indices_flt}") - masked_labels = labels # labels[masked_indices_flt] - # print(f"p_mask {p_mask.shape} {masked_indices.shape}") - p_mask[masked_indices] - - # Compute MLM loss + # Currently by not doing any thing we have both AR loss and Diffusion loss treated equally. 
loss, grad = cross_entropy_forward_backward( - masked_logits.flatten(0, -2), - masked_labels, + logits.flatten(0, -2), + labels, group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, grad_output=grad_output, implementation=self._cross_entropy_impl, logits_scale_factor=self._logits_scale_factor, + avg_loss=False, # Do not average the loss, we will do it later ) + # print(f"loss: {loss.shape} {loss}") + # print(f"grad: {grad.shape} {grad}") # print(f"loss: {loss.shape} {loss}") # Revisit this with the formula and what happens inside the cross_entropy_forward_backward # MDM https://github.com/ML-GSAI/SMDM/blob/583aa4716d17728dbb825aec6c24a121164d616a/pretrain/train_mdm.py#L274 # loss = loss / masked_p # print(f"loss: {loss.shape} {loss}") - # revisit this with the formula and what happens inside the cross_entropy_forward_backward + # We need this when we have a way to weight the losses from each position differently. + # masked_logits = logits[masked_indices].unsqueeze(0) + # print(f"masked_logits: {masked_logits.shape} {masked_logits}") + # # flatten the masked indices to match the logits + # masked_indices_flt = masked_indices.flatten() + # masked_labels = labels[masked_indices_flt] + # print(f"masked_labels: {masked_labels.shape} {masked_labels}") + # p_mask[masked_indices] + + # Take only the losses and grads from the masked tokens/positions + masked_indices_flt = masked_indices.flatten() + masked_loss = loss[masked_indices_flt] + grad[masked_indices_flt] + # print("f masked_probabilities: ", mask_probabilities.shape, mask_probabilities, mask_probabilities.flatten()) + masked_loss = masked_loss / mask_probabilities.flatten()[masked_indices_flt] + + # compute per token loss by all tokens in the batch (tokens we dropped thinks they have 0 loss) # MDM https://github.com/ML-GSAI/SMDM/blob/583aa4716d17728dbb825aec6c24a121164d616a/pretrain/train_mdm.py#L275 - # loss = loss.sum() / (labels.shape[0] * labels.shape[1]) + masked_loss = masked_loss.sum() / labels.shape[0] del logits - return loss, output_parallel_linear_backward(grad, context) + # masked grad or full grad? + return masked_loss, output_parallel_linear_backward(grad, context) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 58528376..167f291f 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -469,68 +469,6 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: ) -# @config_class() -# class DiffusionMaskingConfig(TransformerArchitectureConfig): -# """Configuration for diffusion-based masking in the transformer model. -# This config only contains model-specific parameters. 
For masking parameters, -# refer to fast_llm.data.dataset.gpt.config.DiffusionMaskingConfig.""" - -# enabled: bool = Field( -# default=False, desc="Whether to use diffusion-based masking during training", hint=FieldHint.feature -# ) -# bidirectional_attention: bool = Field( -# default=True, desc="Whether to use bidirectional attention for masked tokens", hint=FieldHint.feature -# ) - -# def _validate(self) -> None: -# super()._validate() - - -# @config_class() -# class DiffusionMaskingConfig(TransformerArchitectureConfig): -# """Configuration for diffusion-based masking during data preparation.""" - -# enabled: bool = Field( -# default=False, -# desc="Whether to use masked diffusion during training", -# hint=FieldHint.feature -# ) - -# epsilon: float = Field( -# default=1e-3, -# desc="Minimum masking probability", -# hint=FieldHint.performance, -# valid=check_field(Assert.gt, 0) -# ) - -# max_mask_prob: float = Field( -# default=0.15, -# desc="Maximum masking probability", -# hint=FieldHint.performance, -# valid=check_field(Assert.gt, 0) -# ) - -# pad_prob: float = Field( -# default=0.01, -# desc="Probability of padding tokens for 1% of samples", -# hint=FieldHint.optional, -# valid=check_field(Assert.geq, 0) -# ) - -# mask_token_id: int = Field( -# default=103, -# desc="Token ID to use for masking", -# hint=FieldHint.optional -# ) - -# def _validate(self) -> None: -# super()._validate() -# Assert.lt(self.epsilon, self.max_mask_prob) #, "epsilon must be less than max_mask_prob") -# Assert.lt(self.max_mask_prob, 1.0,) # "max_mask_prob must be less than 1.0") -# # if self.enabled: -# # Assert.is_not_none(self.mask_token_id, "mask_token_id must be set when masking is enabled") - - @config_class() class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): normalization: NormalizationConfig = FieldUpdate(default_factory=NormalizationConfig) @@ -734,11 +672,6 @@ class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): desc="Use masked-diffusion for training.", hint=FieldHint.feature, ) - # diffusion: DiffusionMaskingConfig = Field( - # default_factory=DiffusionMaskingConfig, - # desc="Configuration for masked diffusion training.", - # hint=FieldHint.feature, - # ) def _validate(self) -> None: with self._set_implicit_default(): diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index 0279abaf..5b7a1185 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -340,100 +340,3 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: ) kwargs[TransformerKwargs.max_seqlen_q] = seqlens_q.max() kwargs[TransformerKwargs.max_seqlen_k] = seqlens_k.max() - - -class LLaDAMaskingPreprocessor(Preprocessor): - """Preprocessor for LLaDA-style masking with diffusion-based training.""" - - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace): - self._config = config - self._tensor_space = tensor_space - self._distributed_config = tensor_space.distributed_config - # print(f"tensor_space: {tensor_space._tensor_dims.keys()}") - self._scalar_dim = tensor_space.get_tensor_dim(DefaultDimNames.scalar) - # self._sequence_dim = tensor_space.get_tensor_dim(TransformerDimNames.sequence_q) - - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - """Apply LLaDA-style masking to the input sequence.""" - # Get diffusion config from dataset parameters - # print(f"kwargs: {kwargs.keys()}") - - print(f"1 batch: {type(batch)} {batch.shape}") - - 
diffusion_config = self._config.diffusion - if not diffusion_config.enabled: - return - - batch_size, seq_len = batch.shape - device = batch.device - mask_token_id = diffusion_config.mask_token_id - - # Generate a random tensor of batch size to seed masking probabilities - t = torch.rand((batch_size,), device=device) - - # Compute the mask probabilities for every sequence in the batch - p_mask = (1 - diffusion_config.epsilon) * t + diffusion_config.epsilon - - # Do we need to clamp at max_mask_prob? - # p_mask = torch.min(p_mask, torch.tensor(diffusion_config.max_mask_prob)) - - # Repeat the same mask probability for each token in the sequence - p_mask = p_mask[:, None].repeat(1, seq_len) - print(f"2 p_mask: {p_mask} {p_mask.shape}") - - # Generate random values for all tokens in the batch and only mask the positions\ - # where the value is smaller than the mask probability - masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask - - # Need further classification of this padding - 1% data to have partial sequences and padding - # if diffusion_config.pad_prob > 0: - # pad_mask = torch.rand((batch_size,), device=device) < diffusion_config.pad_prob - # if pad_mask.any(): - # masked_indices[pad_mask] = True - - # Replace masked tokens with the mask token ID to create input for the model. - noisy_batch = torch.where(masked_indices, mask_token_id, batch) - - kwargs["masked_indices"] = masked_indices - kwargs["p_mask"] = p_mask - kwargs["noisy_batch"] = noisy_batch - - # Bidirectional attention - all tokens can attend to all other tokens - attention_mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool) - - # if self._config.bidirectional_attention: - # # Bidirectional attention - all tokens can attend to all other tokens - # attention_mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool) - # else: - # # Causal attention - # attention_mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool).tril_() - - kwargs[TransformerKwargs.attention_mask] = attention_mask - kwargs[TransformerKwargs.attention_mask_value] = torch.tensor(-10000.0, device=device) - - # def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - # """Define tensor metadata for masking tensors.""" - - # print(f"kwargs: {kwargs.keys()}") - # # Get diffusion config from dataset parameters - # sequence_q_dim = kwargs[TransformerKwargs.sequence_q_dim].size - # diffusion_config = kwargs['parameters'].diffusion - # if not diffusion_config.enabled: - # return - - # kwargs['masked_indices'] = TensorMeta.from_dims( - # (self._scalar_dim, sequence_q_dim), - # tensor_name='masked_indices' - # ) - # kwargs['p_mask'] = TensorMeta.from_dims( - # (self._scalar_dim, sequence_q_dim), - # tensor_name='p_mask' - # ) - # kwargs[TransformerKwargs.attention_mask] = TensorMeta.from_dims( - # (self._scalar_dim, self._scalar_dim, sequence_q_dim, self._sequence_dim), - # tensor_name=TransformerKwargs.attention_mask - # ) - # kwargs[TransformerKwargs.attention_mask_value] = TensorMeta.from_dims( - # (self._scalar_dim,), - # tensor_name=TransformerKwargs.attention_mask_value - # ) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index b0c9514f..83f294bf 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -322,8 +322,8 @@ def preprocess( -10000.0, device=self._tensor_space.distributed.device ) - # print(f"batch.token_ids aka inputs: {batch.token_ids}") - # print(f"labels: {labels}") + # 
print(f"batch.token_ids aka inputs: {batch.token_ids.shape} {batch.token_ids}") + # print(f"labels: {labels.shape} {labels}") for preprocessor in self._preprocessors: # Update this include p_maks and mask index in kwargs preprocessor.preprocess(tokens, kwargs) From d71e69369e4735fbfe49fec1cf9862e679d373c3 Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Wed, 11 Jun 2025 17:27:57 +0000 Subject: [PATCH 07/34] clean up --- fast_llm/layers/language_model/head.py | 41 -------------------------- 1 file changed, 41 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 7d64ff4f..0be1493a 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -315,47 +315,6 @@ def __init__( super().__init__(config, tensor_space, prediction_distance) self._loss_name = LanguageModelLossNames.mlm_loss - # def forward( - # self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None - # ) -> torch.Tensor: - # if isinstance(input_, TensorMeta): - # return TensorMeta.from_tensor_space( - # (DefaultDimNames.scalar,), - # self._tensor_space, - # tensor_name="Loss", - # reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa - # ) - - # # Dropping MTP Stuff - # # if not self.is_last_head: - # # # MTP: split the stacked input - # # shared_hidden, input_ = torch.unbind(input_, dim=0) - - # # TODO: Pytorch copies the grads in backward for no reason (not sure if still the case) - # # TODO: Torch compile implementation sometimes break. - # # TODO: Double-check correctness, optimize a bit more. - # # TODO: Drop autograd entirely. - # # TODO: Skip cross-entropy backward if not needed. - - # print(f"forward input_: {input_.shape} {input_}") - - # # Input needs to be the masked input with masked tokens - # input_ = kwargs["noisy_batch"] - - # language_model_loss = self._forward(input_, kwargs, losses) - # if language_model_loss is not None: - # losses[self._loss_name].append(language_model_loss) - # # TODO: Return the model output when needed. - # # if self.is_last_head: - # # # Last head should return the loss for backward. - # return language_model_loss - # # else: - # # if self.training: - # # # Backward hook to compute the gradient of the loss - # # shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, 1.0) - # # # MTP: Return shared_hidden to be used by the next head. - # # return shared_hidden - def _logits_cross_entropy_forward_backward( self, input_: torch.Tensor, From 072e6c4d27bedb59e2108e2dca3ca064c842683d Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Fri, 13 Jun 2025 17:35:29 +0000 Subject: [PATCH 08/34] add loss weight --- fast_llm/functional/cross_entropy.py | 18 ++++++++--- fast_llm/layers/language_model/head.py | 45 +++++++++++++++++++------- 2 files changed, 47 insertions(+), 16 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 04765a1f..53ef7f30 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -39,8 +39,8 @@ def fused_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, grad_output: float | None, + loss_weight: torch.Tensor | None, logits_scale_factor: float = 1.0, - avg_loss: bool = True, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A fused implementation of cross-entropy with torch compile. 
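The `loss_weight` argument threaded through these hunks scales each per-sample loss and the matching row of the logit gradient before the mean is taken. Because every row's loss depends only on that row's logits, this is equivalent to weighting an ordinary per-sample cross-entropy, as the small check below illustrates (tensors and values here are invented for the example, separate from the fused kernel):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(6, 11, requires_grad=True)
target = torch.randint(0, 11, (6,))
weight = torch.tensor([2.0, 0.0, 1.0, 0.5, 0.0, 1.0])  # e.g. masked_indices / p_mask, flattened

per_sample = F.cross_entropy(logits, target, reduction="none")
loss = (per_sample * weight).mean()   # weighted mean of per-sample losses, then averaged
loss.backward()

# Rows with zero weight receive no gradient, so unweighted positions do not push the model.
assert torch.allclose(logits.grad[1], torch.zeros(11))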
@@ -72,8 +72,16 @@ def fused_cross_entropy_forward_backward( per_sample_loss = sum_exp_logits.log().sub(logits_norm.gather(1, target).squeeze(1)) * loss_mask - # print(f"per_sample_loss: {per_sample_loss} {per_sample_loss.shape}") - return per_sample_loss.mean() if avg_loss else per_sample_loss, grad + print(f"per_sample_loss: {per_sample_loss} {per_sample_loss.shape} {grad.shape if grad is not None else None} ") + if loss_weight is None: + return per_sample_loss.mean(), grad + else: + per_sample_loss = per_sample_loss * loss_weight.flatten() + print(f"per_sample_loss: {per_sample_loss} {per_sample_loss.shape}") + loss_weight_expanded = loss_weight.reshape(-1, 1) + grad = grad * loss_weight_expanded if grad is not None else None + print(f"grad {grad.shape if grad is not None else None} ") + return per_sample_loss.mean(), grad @torch.compile @@ -142,7 +150,7 @@ def cross_entropy_forward_backward( group: ProcessGroup | None, implementation: CrossEntropyImpl = CrossEntropyImpl.fused, logits_scale_factor: float = 1.0, - avg_loss: bool = True, + loss_weight: bool = True, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Select the appropriate implementation of cross-entropy. @@ -158,5 +166,5 @@ def cross_entropy_forward_backward( ) else: return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation]( - logits, target, grad_output, logits_scale_factor=logits_scale_factor, avg_loss=avg_loss + logits, target, grad_output, logits_scale_factor=logits_scale_factor, loss_weight=loss_weight ) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 0be1493a..fadb01f1 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -363,13 +363,36 @@ def _logits_cross_entropy_forward_backward( # print(f"context: {context[0].shape} {context}") # print(f"logits {logits.shape} {logits}") # print(f"labels: {labels.shape} {labels}") - # print(f"masked_indices: {masked_indices.shape} {masked_indices}") + print(f"masked_indices: {masked_indices.shape} {masked_indices}") + print(f"mask_probabilities: {mask_probabilities.shape} {mask_probabilities}") # Compute CrossEntropy loss and weight each loss differently # We use grad from all the input positions for backward pass. # Find a way to weight the individual losses from each position seperatly, leave the grads alone. # only get grads fron the masked positions ??? + last_weight = 0 + B = logits.shape[0] + p_mask = mask_probabilities[:, 0] # same repeated + print(f"p_mask: {p_mask.shape} {p_mask} B: {B}") + tmp = masked_indices[:, 1:] / p_mask[:, None] + print(f"{tmp.shape} {tmp}") + print(f"{torch.ones(B).shape}") + + loss_weight = torch.cat( + ( + # ar_weight * in_context[:, 1:] + # not implement yet + masked_indices[:, 1:] / p_mask[:, None], + # + un_weight * ((1-epsilon) * in_shuffled[:, 1:] + epsilon * in_clean[:, 1:]) / (1 - p_mask[:, None]) # not implement yet + (last_weight * torch.ones(B, device=logits.device)).unsqueeze(1), + # This may need some weighting in terms of masking... + # Let's do last_weight=0 for now? + ), + dim=1, + ).to(logits.dtype) + + print(f"loss_weight: {loss_weight.shape} {loss_weight}") + # Currently by not doing any thing we have both AR loss and Diffusion loss treated equally. 
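The `loss_weight` built above encodes a simple objective: per-token cross-entropy at masked positions, each divided by its sequence's masking probability, averaged over every token in the batch, with unmasked positions contributing zero. A hedged sketch of that estimator in plain PyTorch, independent of the fused kernel used here; the function name and shapes are illustrative:

import torch
import torch.nn.functional as F

def masked_diffusion_loss(logits: torch.Tensor, labels: torch.Tensor,
                          masked_indices: torch.Tensor, p_mask: torch.Tensor) -> torch.Tensor:
    # logits: (batch, seq, vocab); labels: (batch, seq); masked_indices: bool (batch, seq); p_mask: (batch,)
    per_token = F.cross_entropy(logits.flatten(0, 1), labels.flatten(), reduction="none").view_as(labels)
    # Importance-weight each masked position by 1 / p_mask of its sequence; drop unmasked positions.
    weighted = torch.where(masked_indices, per_token / p_mask[:, None], torch.zeros_like(per_token))
    # Normalize by the total number of tokens, not only the masked ones.
    return weighted.sum() / labels.numel()

Dividing by p_mask compensates for how many positions each sequence contributes: lightly masked sequences produce fewer but more heavily weighted terms, which keeps the estimate of the diffusion objective stable across the random masking rates.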
loss, grad = cross_entropy_forward_backward( logits.flatten(0, -2), @@ -378,11 +401,11 @@ def _logits_cross_entropy_forward_backward( grad_output=grad_output, implementation=self._cross_entropy_impl, logits_scale_factor=self._logits_scale_factor, - avg_loss=False, # Do not average the loss, we will do it later + loss_weight=loss_weight, # Do not average the loss, we will do it later ) - # print(f"loss: {loss.shape} {loss}") - # print(f"grad: {grad.shape} {grad}") + print(f"loss: {loss.shape} {loss}") + print(f"grad: {grad.shape} ") # print(f"loss: {loss.shape} {loss}") # Revisit this with the formula and what happens inside the cross_entropy_forward_backward # MDM https://github.com/ML-GSAI/SMDM/blob/583aa4716d17728dbb825aec6c24a121164d616a/pretrain/train_mdm.py#L274 @@ -399,16 +422,16 @@ def _logits_cross_entropy_forward_backward( # p_mask[masked_indices] # Take only the losses and grads from the masked tokens/positions - masked_indices_flt = masked_indices.flatten() - masked_loss = loss[masked_indices_flt] - grad[masked_indices_flt] - # print("f masked_probabilities: ", mask_probabilities.shape, mask_probabilities, mask_probabilities.flatten()) - masked_loss = masked_loss / mask_probabilities.flatten()[masked_indices_flt] + # masked_indices_flt = masked_indices.flatten() + # masked_loss = loss[masked_indices_flt] + # grad[masked_indices_flt] + # # print("f masked_probabilities: ", mask_probabilities.shape, mask_probabilities, mask_probabilities.flatten()) + # masked_loss = masked_loss / mask_probabilities.flatten()[masked_indices_flt] # compute per token loss by all tokens in the batch (tokens we dropped thinks they have 0 loss) # MDM https://github.com/ML-GSAI/SMDM/blob/583aa4716d17728dbb825aec6c24a121164d616a/pretrain/train_mdm.py#L275 - masked_loss = masked_loss.sum() / labels.shape[0] + # masked_loss = masked_loss.sum() / labels.shape[0] del logits # masked grad or full grad? 
- return masked_loss, output_parallel_linear_backward(grad, context) + return loss, output_parallel_linear_backward(grad, context) From 6127544d76e09480b5fdb889e9d0f1c8512bb02e Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Fri, 13 Jun 2025 20:47:37 +0000 Subject: [PATCH 09/34] adding updates to p_mask --- fast_llm/data/data/gpt/data.py | 7 ++++--- fast_llm/functional/cross_entropy.py | 6 +++--- fast_llm/layers/language_model/head.py | 22 +++++++++++----------- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 614913b2..5bdb19fc 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -63,7 +63,7 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling # p_mask = torch.min(p_mask, torch.tensor(diffusion_config.max_mask_prob)) # Repeat the same mask probability for each token in the sequence - mask_probabilities = p_mask[:, None].repeat(1, seq_len) + # mask_probabilities = p_mask[:, None].repeat(1, seq_len) # Input has an additional token for shitting, is [0, 1, 2, 3, 4] -> [1, 2, 3, 4] # index [0, 1, 2, 3, 4, 5] -> @@ -78,7 +78,7 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling # Generate random values for all tokens in the batch and only mask the positions\ # where the value is smaller than the mask probability - mask_indexes = torch.rand((batch_size, seq_len)) < mask_probabilities + mask_indexes = torch.rand((batch_size, seq_len)) < p_mask[:, None] # Need further classification of this padding - 1% data to have partial sequences and padding # if diffusion_config.pad_prob > 0: @@ -90,7 +90,8 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling token_ids = torch.where(mask_indexes, mask_token_id, token_ids) mask_indexes = mask_indexes[:, :-1] # Remove the last token, which is not used for prediction. - mask_probabilities = mask_probabilities[:, :-1] # Remove the last token, which is not used for prediction. + # mask_probabilities = mask_probabilities[:, :-1] # Remove the last token, which is not used for prediction. 
+ mask_probabilities = p_mask if sampling_parameters.use_loss_masking_spans: stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 53ef7f30..1c7e7e21 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -72,15 +72,15 @@ def fused_cross_entropy_forward_backward( per_sample_loss = sum_exp_logits.log().sub(logits_norm.gather(1, target).squeeze(1)) * loss_mask - print(f"per_sample_loss: {per_sample_loss} {per_sample_loss.shape} {grad.shape if grad is not None else None} ") + # print(f"per_sample_loss: {per_sample_loss} {per_sample_loss.shape} {grad.shape if grad is not None else None} ") if loss_weight is None: return per_sample_loss.mean(), grad else: per_sample_loss = per_sample_loss * loss_weight.flatten() - print(f"per_sample_loss: {per_sample_loss} {per_sample_loss.shape}") + # print(f"per_sample_loss: {per_sample_loss} {per_sample_loss.shape}") loss_weight_expanded = loss_weight.reshape(-1, 1) grad = grad * loss_weight_expanded if grad is not None else None - print(f"grad {grad.shape if grad is not None else None} ") + # print(f"grad {grad.shape if grad is not None else None} ") return per_sample_loss.mean(), grad diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index fadb01f1..d1887ce6 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -350,7 +350,7 @@ def _logits_cross_entropy_forward_backward( return logits * self._logits_scale_factor, None masked_indices = kwargs[LanguageModelKwargs.mask_indexes] - mask_probabilities = kwargs[LanguageModelKwargs.mask_probabilities] + p_mask = kwargs[LanguageModelKwargs.mask_probabilities] # index [0, 1, 2, 3, 4, 5] -> # The labels are already left shifted x = [A, B, C, D, E, F] -> # embd = [A, B, C, D, E] @@ -363,8 +363,8 @@ def _logits_cross_entropy_forward_backward( # print(f"context: {context[0].shape} {context}") # print(f"logits {logits.shape} {logits}") # print(f"labels: {labels.shape} {labels}") - print(f"masked_indices: {masked_indices.shape} {masked_indices}") - print(f"mask_probabilities: {mask_probabilities.shape} {mask_probabilities}") + # print(f"masked_indices: {masked_indices.shape} {masked_indices}") + # print(f"mask_probabilities: {mask_probabilities.shape} {mask_probabilities}") # Compute CrossEntropy loss and weight each loss differently # We use grad from all the input positions for backward pass. @@ -373,11 +373,11 @@ def _logits_cross_entropy_forward_backward( last_weight = 0 B = logits.shape[0] - p_mask = mask_probabilities[:, 0] # same repeated - print(f"p_mask: {p_mask.shape} {p_mask} B: {B}") - tmp = masked_indices[:, 1:] / p_mask[:, None] - print(f"{tmp.shape} {tmp}") - print(f"{torch.ones(B).shape}") + # p_mask = mask_probabilities[:, 0] # same repeated + # print(f"p_mask: {p_mask.shape} {p_mask} B: {B}") + # tmp = masked_indices[:, 1:] / p_mask[:, None] + # print(f"{tmp.shape} {tmp}") + # print(f"{torch.ones(B).shape}") loss_weight = torch.cat( ( @@ -391,7 +391,7 @@ def _logits_cross_entropy_forward_backward( dim=1, ).to(logits.dtype) - print(f"loss_weight: {loss_weight.shape} {loss_weight}") + # print(f"loss_weight: {loss_weight.shape} {loss_weight}") # Currently by not doing any thing we have both AR loss and Diffusion loss treated equally. 
loss, grad = cross_entropy_forward_backward( @@ -404,8 +404,8 @@ def _logits_cross_entropy_forward_backward( loss_weight=loss_weight, # Do not average the loss, we will do it later ) - print(f"loss: {loss.shape} {loss}") - print(f"grad: {grad.shape} ") + # print(f"loss: {loss.shape} {loss}") + # print(f"grad: {grad.shape} ") # print(f"loss: {loss.shape} {loss}") # Revisit this with the formula and what happens inside the cross_entropy_forward_backward # MDM https://github.com/ML-GSAI/SMDM/blob/583aa4716d17728dbb825aec6c24a121164d616a/pretrain/train_mdm.py#L274 From 1cf15a8249e11f47aa1ba3b38d52271e9209879d Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Mon, 16 Jun 2025 13:48:45 +0000 Subject: [PATCH 10/34] update error mgs --- fast_llm/functional/cross_entropy.py | 5 +++++ fast_llm/layers/language_model/head.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 1c7e7e21..0c478770 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -12,6 +12,7 @@ def torch_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, grad_output: float | None, + loss_weight: torch.Tensor | None, logits_scale_factor: float = 1.0, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ @@ -21,6 +22,7 @@ def torch_cross_entropy_forward_backward( TODO: loss masking only works for this method if the masking index is set to -100. """ # Torch compile doesn't understand this. + assert loss_weight is None, "Loss weight not supported in torch cross-entropy implementation." with torch.enable_grad(): logits_ = logits.float().detach().requires_grad_() if logits_scale_factor != 1.0: @@ -90,6 +92,7 @@ def parallel_cross_entropy_forward_backward( target: torch.Tensor, grad_output: float | None, group: ProcessGroup, + loss_weight: torch.Tensor | None, logits_scale_factor: float = 1.0, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ @@ -98,6 +101,8 @@ def parallel_cross_entropy_forward_backward( """ # TODO: Compiled version incorrect for some inputs (32 bit indexing issue?). # TODO: Optimize, overlap/combine reductions + assert loss_weight is None, "Loss weight not supported in parallel cross-entropy implementation." 
+ loss_mask = target >= 0 target = target.unsqueeze(1) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index d1887ce6..61d09f88 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -351,7 +351,7 @@ def _logits_cross_entropy_forward_backward( masked_indices = kwargs[LanguageModelKwargs.mask_indexes] p_mask = kwargs[LanguageModelKwargs.mask_probabilities] - # index [0, 1, 2, 3, 4, 5] -> + # index [0, 1, 2, 3, 4, 5] -> # The labels are already left shifted x = [A, B, C, D, E, F] -> # embd = [A, B, C, D, E] # label = [B, C, D, E, F] From f7a46d777dc62228a560027f8d92941c70054b3b Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Mon, 16 Jun 2025 20:25:56 +0000 Subject: [PATCH 11/34] add comments and clean up --- fast_llm/data/data/gpt/data.py | 9 +--- fast_llm/functional/cross_entropy.py | 6 +-- fast_llm/layers/language_model/head.py | 57 ++++---------------------- fast_llm/models/gpt/model.py | 7 ---- 4 files changed, 10 insertions(+), 69 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 5bdb19fc..83974645 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -57,13 +57,11 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling t = torch.rand((batch_size,)) # Compute the mask probabilities for every sequence in the batch - p_mask = (1 - diffusion_config.epsilon) * t + diffusion_config.epsilon + p_mask = (1 - (2 * diffusion_config.epsilon)) * t + diffusion_config.epsilon # Do we need to clamp at max_mask_prob? # p_mask = torch.min(p_mask, torch.tensor(diffusion_config.max_mask_prob)) - # Repeat the same mask probability for each token in the sequence - # mask_probabilities = p_mask[:, None].repeat(1, seq_len) # Input has an additional token for shitting, is [0, 1, 2, 3, 4] -> [1, 2, 3, 4] # index [0, 1, 2, 3, 4, 5] -> @@ -72,10 +70,6 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling # label = [B, C, D, E, F] # Last input token is dropped from the processing - # We should not mask it - # mask_probabilities[:, 0] = 0.0 - # print(f"2 p_mask: {mask_probabilities} {mask_probabilities.shape}") - # Generate random values for all tokens in the batch and only mask the positions\ # where the value is smaller than the mask probability mask_indexes = torch.rand((batch_size, seq_len)) < p_mask[:, None] @@ -90,7 +84,6 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling token_ids = torch.where(mask_indexes, mask_token_id, token_ids) mask_indexes = mask_indexes[:, :-1] # Remove the last token, which is not used for prediction. - # mask_probabilities = mask_probabilities[:, :-1] # Remove the last token, which is not used for prediction. 
mask_probabilities = p_mask if sampling_parameters.use_loss_masking_spans: diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 0c478770..0ba302e1 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -74,15 +74,13 @@ def fused_cross_entropy_forward_backward( per_sample_loss = sum_exp_logits.log().sub(logits_norm.gather(1, target).squeeze(1)) * loss_mask - # print(f"per_sample_loss: {per_sample_loss} {per_sample_loss.shape} {grad.shape if grad is not None else None} ") if loss_weight is None: return per_sample_loss.mean(), grad else: per_sample_loss = per_sample_loss * loss_weight.flatten() - # print(f"per_sample_loss: {per_sample_loss} {per_sample_loss.shape}") loss_weight_expanded = loss_weight.reshape(-1, 1) grad = grad * loss_weight_expanded if grad is not None else None - # print(f"grad {grad.shape if grad is not None else None} ") + # Avg across all the tokens. return per_sample_loss.mean(), grad @@ -155,7 +153,7 @@ def cross_entropy_forward_backward( group: ProcessGroup | None, implementation: CrossEntropyImpl = CrossEntropyImpl.fused, logits_scale_factor: float = 1.0, - loss_weight: bool = True, + loss_weight: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Select the appropriate implementation of cross-entropy. diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 61d09f88..1be82287 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -333,8 +333,6 @@ def _logits_cross_entropy_forward_backward( sequence_parallel=self._sequence_parallel and self._parallel_embeddings, ) - # print(f"logits {logits.shape} {logits}") - if self._z_loss_factor > 0.0: logits = z_loss( logits, @@ -358,26 +356,11 @@ def _logits_cross_entropy_forward_backward( # Question Pier: if 2 is the masked token, settling needs to settled 3; but 3 is already seen by the model, # can it just learn to copy 3? i.e copy the next token to the masked? - # Yes. it will we need to include curruption to properly handle this. but it seems not to big looking at other CPT (diffuLlama) - - # print(f"context: {context[0].shape} {context}") - # print(f"logits {logits.shape} {logits}") - # print(f"labels: {labels.shape} {labels}") - # print(f"masked_indices: {masked_indices.shape} {masked_indices}") - # print(f"mask_probabilities: {mask_probabilities.shape} {mask_probabilities}") - - # Compute CrossEntropy loss and weight each loss differently - # We use grad from all the input positions for backward pass. - # Find a way to weight the individual losses from each position seperatly, leave the grads alone. - # only get grads fron the masked positions ??? + # Yes. 
We need to drop those position from loss if the next token is not masked + # We can include curruption to further enhance, but it seems not to big looking at other CPT (diffuLlama) last_weight = 0 B = logits.shape[0] - # p_mask = mask_probabilities[:, 0] # same repeated - # print(f"p_mask: {p_mask.shape} {p_mask} B: {B}") - # tmp = masked_indices[:, 1:] / p_mask[:, None] - # print(f"{tmp.shape} {tmp}") - # print(f"{torch.ones(B).shape}") loss_weight = torch.cat( ( @@ -385,15 +368,11 @@ def _logits_cross_entropy_forward_backward( masked_indices[:, 1:] / p_mask[:, None], # + un_weight * ((1-epsilon) * in_shuffled[:, 1:] + epsilon * in_clean[:, 1:]) / (1 - p_mask[:, None]) # not implement yet (last_weight * torch.ones(B, device=logits.device)).unsqueeze(1), - # This may need some weighting in terms of masking... - # Let's do last_weight=0 for now? + # This may need some weighting in terms of masking. Let's do last_weight=0 TODO: Decide later ), dim=1, ).to(logits.dtype) - # print(f"loss_weight: {loss_weight.shape} {loss_weight}") - - # Currently by not doing any thing we have both AR loss and Diffusion loss treated equally. loss, grad = cross_entropy_forward_backward( logits.flatten(0, -2), labels, @@ -404,34 +383,12 @@ def _logits_cross_entropy_forward_backward( loss_weight=loss_weight, # Do not average the loss, we will do it later ) - # print(f"loss: {loss.shape} {loss}") - # print(f"grad: {grad.shape} ") - # print(f"loss: {loss.shape} {loss}") - # Revisit this with the formula and what happens inside the cross_entropy_forward_backward + # This happens with the loss_weight. # MDM https://github.com/ML-GSAI/SMDM/blob/583aa4716d17728dbb825aec6c24a121164d616a/pretrain/train_mdm.py#L274 - # loss = loss / masked_p - # print(f"loss: {loss.shape} {loss}") - - # We need this when we have a way to weight the losses from each position differently. - # masked_logits = logits[masked_indices].unsqueeze(0) - # print(f"masked_logits: {masked_logits.shape} {masked_logits}") - # # flatten the masked indices to match the logits - # masked_indices_flt = masked_indices.flatten() - # masked_labels = labels[masked_indices_flt] - # print(f"masked_labels: {masked_labels.shape} {masked_labels}") - # p_mask[masked_indices] - - # Take only the losses and grads from the masked tokens/positions - # masked_indices_flt = masked_indices.flatten() - # masked_loss = loss[masked_indices_flt] - # grad[masked_indices_flt] - # # print("f masked_probabilities: ", mask_probabilities.shape, mask_probabilities, mask_probabilities.flatten()) - # masked_loss = masked_loss / mask_probabilities.flatten()[masked_indices_flt] - - # compute per token loss by all tokens in the batch (tokens we dropped thinks they have 0 loss) + + # Compute per token loss by avg across all tokens in the batch (tokens we ignore are assumed to have a 0 loss still counted towards the average) + # done inside the cross-entropy function # MDM https://github.com/ML-GSAI/SMDM/blob/583aa4716d17728dbb825aec6c24a121164d616a/pretrain/train_mdm.py#L275 - # masked_loss = masked_loss.sum() / labels.shape[0] del logits - # masked grad or full grad? 
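The `[:, 1:]` slice and the appended `last_weight` column are bookkeeping for the left-shifted labels: the token reconstructed at masked input position j is scored at label index j - 1, and the final label's own input token was dropped, so its weight is set explicitly. A toy example (values invented) of how the weight vector lines up:

import torch

# One sequence of six sampled tokens, so five positions after shifting:
#   inputs = [A, B, C, D, E]   (token_ids[:-1], with masks applied)
#   labels = [B, C, D, E, F]   (token_ids[1:], kept clean)
mask = torch.tensor([[False, True, False, True, False]])  # which *input* positions were masked
p_mask = torch.tensor([0.25])                             # per-sequence masking rate
last_weight = 0.0                                         # the last label's input token was dropped, weight set by hand

loss_weight = torch.cat(
    (
        mask[:, 1:] / p_mask[:, None],             # label indices 0..3 <- input positions 1..4
        torch.full((mask.shape[0], 1), last_weight),
    ),
    dim=1,
)
# loss_weight == tensor([[4., 0., 4., 0., 0.]]): only labels whose input token was masked are counted,
# each scaled by 1 / p_mask.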
return loss, output_parallel_linear_backward(grad, context) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 83f294bf..c2e7d50f 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -71,9 +71,6 @@ def __init__( else: self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) - # if self._config.transformer.diffusion.enabled: - # self._preprocessors.append(LLaDAMaskingPreprocessor(self._config.transformer, self._tensor_space)) - def get_output_layers(self) -> list[Layer]: return [ layer @@ -303,7 +300,6 @@ def preprocess( if batch.mask_indexes is not None: # We are in masked-diffusion mode, so we need to add the mask indexes and probabilities to kwargs - # print(f'in masked-diffusion mode, batch.mask_indexes: {batch.mask_indexes}') kwargs[LanguageModelKwargs.mask_indexes] = batch.mask_indexes.to( device=self._tensor_space.distributed.device ) @@ -322,10 +318,7 @@ def preprocess( -10000.0, device=self._tensor_space.distributed.device ) - # print(f"batch.token_ids aka inputs: {batch.token_ids.shape} {batch.token_ids}") - # print(f"labels: {labels.shape} {labels}") for preprocessor in self._preprocessors: - # Update this include p_maks and mask index in kwargs preprocessor.preprocess(tokens, kwargs) preprocessed.append((tokens, kwargs)) From 01a683bfc844919cec3441b8c8cf4666edbd4747 Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Wed, 18 Jun 2025 14:54:28 +0000 Subject: [PATCH 12/34] fx merge errors --- fast_llm/engine/base_model/base_model.py | 1 - fast_llm/engine/schedule/runner.py | 1 - fast_llm/layers/transformer/preprocessing.py | 9 +++------ 3 files changed, 3 insertions(+), 8 deletions(-) diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 1f4921ae..51252f03 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -137,7 +137,6 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: # The name (dict key) is used to insert the weight in the kwargs of the forward pass. 
return {} - # @property @abc.abstractmethod def get_loss_defs(self) -> list[LossDef]: pass diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index f5f0e111..f2b302de 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -93,7 +93,6 @@ def __init__( self._stages: list[Stage] = self._multi_stage.stages self._tied_parameters = self._multi_stage.tied_parameters self._num_stages = len(self._stages) - print(f"base_model={type(self._multi_stage.base_model)} {self._multi_stage.base_model.get_loss_defs()} ") self._loss_defs = {loss_def.name: loss_def for loss_def in self._multi_stage.base_model.get_loss_defs()} def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> None: diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index 031ccd30..dc3ddeb5 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -15,8 +15,8 @@ class BackupAttentionPreprocessor(Preprocessor): _scalar_dim: TensorDim _kv_channels_dim: TensorDim _rotary_embedding_frequencies: torch.Tensor - _mask: torch.Tensor | None - _mask_value: torch.Tensor | None + _mask: torch.Tensor + _mask_value: torch.Tensor _tensor_cache_max_sequence_length: int = -1 def __init__( @@ -24,7 +24,6 @@ def __init__( config: TransformerConfig, tensor_space: TensorSpace, ): - super().__init__() self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config @@ -70,8 +69,6 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: kwargs[TransformerKwargs.attention_mask] & document_mask[:, None, sequence_k - sequence_q : sequence_k, None, :sequence_k] ) - - # can we add a bidirectional attention here? kwargs[TransformerKwargs.attention_mask_value] = self._mask_value def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: @@ -89,7 +86,7 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: kwargs[TransformerKwargs.attention_mask_value] = TensorMeta.from_dims( (self._scalar_dim,), tensor_name=TransformerKwargs.attention_mask_value, - dtype=self._distributed_config.training_dtype.torch, + dtype=self._tensor_space.distributed_config.training_dtype.torch, ) From ba913e1987cdfe9e2b46e37b017c0c28990c7e52 Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Wed, 18 Jun 2025 15:45:31 +0000 Subject: [PATCH 13/34] fix merge issues --- fast_llm/layers/__init__.py | 1 - fast_llm/models/custom/model.py | 3 +- fast_llm/models/gpt/model.py | 89 ++++++++++++++++----------------- 3 files changed, 44 insertions(+), 49 deletions(-) diff --git a/fast_llm/layers/__init__.py b/fast_llm/layers/__init__.py index 8b137891..e69de29b 100644 --- a/fast_llm/layers/__init__.py +++ b/fast_llm/layers/__init__.py @@ -1 +0,0 @@ - diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index f05b5a89..8b1f1635 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -59,10 +59,9 @@ def preprocess( # TODO: Adjust or reimplement. return super().preprocess(batch, preprocessed_meta, phase=phase, iteration=iteration, metrics=metrics) - # @property def get_loss_defs(self) -> list[LossDef]: # TODO: Adjust or reimplement. 
- return super().loss_defs + return super().get_loss_defs() class CustomModel[ConfigType: CustomBaseModelConfig](GPTModel[ConfigType]): diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 3fd3eb40..d1e01511 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -64,34 +64,37 @@ def __init__( self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._tensor_space)) def get_output_layers(self) -> list[Layer]: - return [ - layer - for i in range(self._config.prediction_heads) - for layer in [ - TransformerLayer( - self._config.transformer, - self._tensor_space, - # TODO MTP: which index? - layer_index=self._config.transformer.num_layers, - # The last layer only returns the transformer output. - # The previous layers return a stack of shared_hidden and transformer_output. - return_input=i < self._config.prediction_heads - 1, - ), - ( + layers = [] + for i in range(self._config.prediction_heads): + if i > 0: + layers.append( + TransformerLayer( + self._config.transformer, + self._tensor_space, + # TODO MTP: which index? + layer_index=max(self._config.transformer.num_layers + i, 1), + # The last layer only returns the transformer output. + # The previous layers return a stack of shared_hidden and transformer_output. + return_input=i < self._config.prediction_heads - 1, + ) + ) + if self._config.transformer.diffusion: + layers.append( MLMHead( self._config, self._tensor_space, prediction_distance=i, ) - if self._config.transformer.diffusion - else LanguageModelHead( + ) + else: + layers.append( + LanguageModelHead( self._config, self._tensor_space, prediction_distance=i, ) - ), - ] - ] + ) + return layers def get_layers(self) -> list[Layer]: return [ @@ -392,44 +395,38 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: else: return {} - # @property def get_loss_defs(self) -> list[LossDef]: loss_defs = [] - if self._config.transformer.diffusion: - # Masked LM Loss for masked-diffusion training - loss_defs.append( - LossDef(name=LanguageModelLossNames.mlm_loss, formatted_name="MLM Loss", count=1, dtype=torch.float32) - ) - else: - # Standard language modeling loss + if ( + self._config.transformer.num_experts > 1 + and self._config.transformer.expert_routing_type == RoutingType.topk + ): loss_defs.append( LossDef( - name=LanguageModelLossNames.language_model_loss, - formatted_name="Language Model Loss", - count=1, - dtype=torch.float32, + name=TransformerLossNames.load_balancing_loss, + formatted_name="load balancing loss", + count=self._config.transformer.num_layers, ) ) - - if self._config.transformer.num_experts > 1: - if self._config.transformer.expert_routing_type == RoutingType.topk: - loss_defs.append( - LossDef( - name=TransformerLossNames.load_balancing_loss, - formatted_name="Load Balancing Loss", - count=1, - dtype=torch.float32, - ) - ) - if self._config.transformer.expert_z_loss_coefficient > 0: + if self._config.transformer.expert_z_loss_coefficient: loss_defs.append( LossDef( name=TransformerLossNames.router_z_loss, - formatted_name="Router Z Loss", - count=1, - dtype=torch.float32, + formatted_name="router z loss", + count=self._config.transformer.num_layers, ) ) + if self._config.logit_z_loss: + LossDef(name=LanguageModelLossNames.z_loss, formatted_name="logit z loss", count=1) + + for i in range(self._config.prediction_heads): + loss_defs.append( + LossDef( + name=LanguageModelLossNames.multi_token_prediction_loss(i), + formatted_name=f"language model loss {i}", + count=1, + ) + ) 
return loss_defs From 6c0c72d136fb0d6a9079f895d8070e496f60f77d Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Wed, 18 Jun 2025 16:18:39 +0000 Subject: [PATCH 14/34] register mask config --- fast_llm/data/dataset/gpt/config.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 6b01f247..4b78ae34 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -44,7 +44,7 @@ class ShufflingType(str, enum.Enum): legacy = "legacy" -@config_class() +@config_class(registry=True) class DiffusionMaskingConfig(Config): """Configuration for diffusion-based masking during data preparation.""" @@ -75,9 +75,7 @@ def _validate(self) -> None: Assert.lt( self.max_mask_prob, 1.0, - ) # "max_mask_prob must be less than 1.0") - # if self.enabled: - # Assert.is_not_none(self.mask_token_id, "mask_token_id must be set when masking is enabled") + ) @config_class() @@ -99,7 +97,6 @@ class GPTSamplingConfig(SamplingConfig): hint=FieldHint.feature, ) diffusion: DiffusionMaskingConfig = Field( - default_factory=DiffusionMaskingConfig, desc="Configuration for diffusion-based masking during data preparation.", hint=FieldHint.feature, ) @@ -119,11 +116,7 @@ class GPTSamplingParameters(SamplingParameters): # How many extra tokens to add to the sequence length. # This is used to provide labels even for the last tokens in the sequence. extra_tokens: int = 1 - diffusion: DiffusionMaskingConfig = Field( - default_factory=DiffusionMaskingConfig, - desc="Configuration for diffusion-based masking during data preparation. Will be copied from GPTSamplingConfig during ", - hint=FieldHint.feature, - ) + diffusion: DiffusionMaskingConfig @dataclasses.dataclass(kw_only=True) From 324549679165d8f6c360117917fbe6114e095f63 Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Wed, 18 Jun 2025 18:19:42 +0000 Subject: [PATCH 15/34] fx merge issues --- fast_llm/functional/cross_entropy.py | 3 ++- fast_llm/layers/language_model/head.py | 29 +++++++++----------------- fast_llm/models/gpt/model.py | 4 ++++ 3 files changed, 16 insertions(+), 20 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 78d0c359..7fbf285c 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -158,6 +158,7 @@ def _fused_cross_entropy_forward_backward( all_reduce(loss, op=ReduceOp.MEAN, group=group) return loss, grad else: + # Weight every token loss by the loss weight. Before averaging. 
per_sample_loss = per_sample_loss * loss_weight.flatten() loss_weight_expanded = loss_weight.reshape(-1, 1) grad = grad * loss_weight_expanded if grad is not None else None @@ -206,5 +207,5 @@ def cross_entropy_forward_backward( ) else: return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation]( - logits, target, grad_output, logits_scale_factor=logits_scale_factor, loss_weight=loss_weight + logits, target, loss_mask, grad_output, logits_scale_factor, target_format, loss_weight=loss_weight ) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 5efd51b4..5c19e956 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -378,13 +378,17 @@ def __init__( def _logits_cross_entropy_forward_backward( self, input_: torch.Tensor, - labels: torch.Tensor | None, + target: torch.Tensor | None, + loss_mask: torch.Tensor | None, weight: torch.Tensor, grad_output: float, kwargs: dict, losses: dict | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: + assert target is not None, "MLM head requires target labels" + assert loss_mask is None, "MLM head does not support loss mask" + logits, context = output_parallel_linear_forward( input_=input_, weight=weight, @@ -393,20 +397,6 @@ def _logits_cross_entropy_forward_backward( sequence_parallel=self._sequence_parallel and self._parallel_embeddings, ) - if self._z_loss_factor > 0.0: - logits = z_loss( - logits, - self._z_loss_factor, - self.training, - grad_output, - losses, - LanguageModelLossNames.z_loss, - logits_scale_factor=self._logits_scale_factor, - ) - - if labels is None: - return logits * self._logits_scale_factor, None - masked_indices = kwargs[LanguageModelKwargs.mask_indexes] p_mask = kwargs[LanguageModelKwargs.mask_probabilities] # index [0, 1, 2, 3, 4, 5] -> @@ -434,13 +424,14 @@ def _logits_cross_entropy_forward_backward( ).to(logits.dtype) loss, grad = cross_entropy_forward_backward( - logits.flatten(0, -2), - labels, - group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, + logits=logits.flatten(0, -2), + target=target, + loss_mask=None, grad_output=grad_output, + group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, implementation=self._cross_entropy_impl, logits_scale_factor=self._logits_scale_factor, - loss_weight=loss_weight, # Do not average the loss, we will do it later + loss_weight=loss_weight, ) # This happens with the loss_weight. 
diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index d1e01511..b08fc54c 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -419,6 +419,10 @@ def get_loss_defs(self) -> list[LossDef]: if self._config.logit_z_loss: LossDef(name=LanguageModelLossNames.z_loss, formatted_name="logit z loss", count=1) + if self._config.transformer.diffusion: + # Masked LM Loss for masked-diffusion training + loss_defs.append(LossDef(name=LanguageModelLossNames.mlm_loss, formatted_name="MLM Loss", count=1)) + for i in range(self._config.prediction_heads): loss_defs.append( LossDef( From 4ad0bc17ca34ad957476acdd07cb22395404c9e2 Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Mon, 23 Jun 2025 19:06:19 +0000 Subject: [PATCH 16/34] fix labels --- fast_llm/data/data/gpt/data.py | 8 +++++--- fast_llm/models/gpt/model.py | 3 +++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 6a640ebe..264ec212 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -32,10 +32,11 @@ class GPTBatch: token_ids: torch.Tensor loss_masking_spans: list[torch.Tensor] | None = None sequence_lengths: list[torch.Tensor] | None = None - mask_indexes: torch.Tensor | None = None - mask_probabilities: torch.Tensor | None = None chosen_spans: list[torch.Tensor] | None = None rejected_spans: list[torch.Tensor] | None = None + mask_indexes: torch.Tensor | None = None + mask_probabilities: torch.Tensor | None = None + masked_token_ids: torch.Tensor | None = None def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch: @@ -83,7 +84,7 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling # mask_indexes[pad_mask] = True # Replace masked tokens with the mask token ID to create input for the model. - token_ids = torch.where(mask_indexes, mask_token_id, token_ids) + masked_token_ids = torch.where(mask_indexes, mask_token_id, token_ids) mask_indexes = mask_indexes[:, :-1] # Remove the last token, which is not used for prediction. 
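The fix in this patch keeps the clean `token_ids` for label construction and stores the corrupted copy separately as `masked_token_ids`. A toy version of that separation, with hypothetical values:

```python
import torch

mask_token_id = 103
token_ids = torch.tensor([[5, 8, 2, 9, 4]])
mask_indexes = torch.tensor([[False, True, False, True, False]])

# Model input: masked positions replaced by the mask token.
masked_token_ids = torch.where(mask_indexes, mask_token_id, token_ids)

# Labels keep coming from the untouched token_ids, so the target at a masked
# position is still the original token rather than the mask token.
print(token_ids)         # tensor([[5, 8, 2, 9, 4]])
print(masked_token_ids)  # tensor([[  5, 103,   2, 103,   4]])
```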
mask_probabilities = p_mask @@ -109,6 +110,7 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling rejected_spans=stacked_rejected_spans, mask_indexes=mask_indexes, mask_probabilities=mask_probabilities, + masked_token_ids=masked_token_ids, ) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 970b86d7..bea192bc 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -353,6 +353,9 @@ def preprocess( -10000.0, device=self._tensor_space.distributed.device ) + # set token ids to masked tokens + batch.token_ids = batch.masked_token_ids + kwargs.update(reference_logits[i]) for preprocessor in self._preprocessors: From acacfe3d45f2477e0eae04f65e071907c9ed3480 Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Mon, 23 Jun 2025 21:04:36 +0000 Subject: [PATCH 17/34] drop old tests --- tests/test_masking.py | 184 ----------------------------------------- tests/test_mlm_loss.py | 149 --------------------------------- 2 files changed, 333 deletions(-) delete mode 100644 tests/test_masking.py delete mode 100644 tests/test_mlm_loss.py diff --git a/tests/test_masking.py b/tests/test_masking.py deleted file mode 100644 index 30ddd69d..00000000 --- a/tests/test_masking.py +++ /dev/null @@ -1,184 +0,0 @@ -import pytest -import torch - -from fast_llm.layers.language_model.preprocessing import LLaDAMaskingPreprocessor -from fast_llm.layers.transformer.config import DiffusionMaskingConfig - - -@pytest.fixture -def masking_config(): - return DiffusionMaskingConfig( - enabled=True, - epsilon=0.1, - max_mask_prob=0.5, - pad_prob=0.1, - mask_token_id=103 - ) - - -def test_masking_basic(): - config = DiffusionMaskingConfig( - enabled=True, - epsilon=0.15, # 15% minimum masking - max_mask_prob=0.5, # 50% maximum masking - pad_prob=0.1, - mask_token_id=103 - ) - - preprocessor = LLaDAMaskingPreprocessor(config) - - batch_size = 4 - seq_len = 10 - input_ids = torch.randint(0, 1000, (batch_size, seq_len)) - - input_ids[:, -2:] = 0 # Add padding at the end - - outputs = preprocessor(input_ids) - - masked_indices = outputs['masked_indices'] - p_mask = outputs['p_mask'] - masked_input = outputs['input_ids'] - - assert masked_indices.shape == input_ids.shape - assert p_mask.shape == input_ids.shape - assert masked_input.shape == input_ids.shape - - padding_positions = (input_ids == 0) - assert not masked_indices[padding_positions].any() - assert (p_mask[padding_positions] == 0).all() - - non_pad_positions = ~padding_positions - assert (p_mask[non_pad_positions] >= config.epsilon).all() - assert (p_mask[non_pad_positions] <= config.max_mask_prob).all() - - assert (masked_input[masked_indices] == config.mask_token_id).all() - - unmasked_positions = ~masked_indices & non_pad_positions - assert (masked_input[unmasked_positions] == input_ids[unmasked_positions]).all() - - -def test_masking_edge_cases(): - config = DiffusionMaskingConfig( - enabled=True, - epsilon=0.1, - max_mask_prob=0.5, - pad_prob=0.1, - mask_token_id=103 - ) - - preprocessor = LLaDAMaskingPreprocessor(config) - - input_ids = torch.randint(0, 1000, (1, 5)) - outputs = preprocessor(input_ids) - assert outputs['masked_indices'].shape == (1, 5) - assert outputs['p_mask'].shape == (1, 5) - - input_ids = torch.zeros(2, 4) - outputs = preprocessor(input_ids) - assert not outputs['masked_indices'].any() # No tokens should be masked - assert (outputs['p_mask'] == 0).all() # All masking probs should be 0 - - input_ids = torch.randint(1, 1000, (2, 4)) # All tokens are non-padding - outputs = 
preprocessor(input_ids) - assert outputs['masked_indices'].any() # Some tokens should be masked - assert (outputs['p_mask'] >= config.epsilon).all() # All probs should be >= epsilon - - input_ids = torch.randint(1, 1000, (1, 1)) - outputs = preprocessor(input_ids) - assert outputs['masked_indices'].shape == (1, 1) - assert outputs['p_mask'].shape == (1, 1) - - -def test_masking_probabilities(): - config = DiffusionMaskingConfig( - enabled=True, - epsilon=0.1, - max_mask_prob=0.5, - pad_prob=0.1, - mask_token_id=103 - ) - - preprocessor = LLaDAMaskingPreprocessor(config) - - input_ids = torch.ones(3, 8) - input_ids[0, :] = torch.arange(1, 9) # Increasing sequence - input_ids[1, :] = torch.arange(8, 0, -1) # Decreasing sequence - input_ids[2, :] = 1 # Constant sequence - - n_trials = 100 - mask_counts = torch.zeros_like(input_ids) - - for _ in range(n_trials): - outputs = preprocessor(input_ids) - mask_counts += outputs['masked_indices'].float() - - empirical_probs = mask_counts / n_trials - - assert (empirical_probs >= config.epsilon - 0.05).all() # Allow small deviation - assert (empirical_probs <= config.max_mask_prob + 0.05).all() - - -def test_masking_deterministic(): - config = DiffusionMaskingConfig( - enabled=True, - epsilon=0.1, - max_mask_prob=0.5, - pad_prob=0.1, - mask_token_id=103 - ) - - preprocessor = LLaDAMaskingPreprocessor(config) - - - torch.manual_seed(42) - - - input_ids = torch.randint(1, 1000, (2, 6)) - - torch.manual_seed(42) - outputs1 = preprocessor(input_ids) - - torch.manual_seed(42) - outputs2 = preprocessor(input_ids) - - assert torch.equal(outputs1['masked_indices'], outputs2['masked_indices']) - assert torch.equal(outputs1['p_mask'], outputs2['p_mask']) - assert torch.equal(outputs1['input_ids'], outputs2['input_ids']) - - -def test_masking_config_validation(): - with pytest.raises(ValueError): - DiffusionMaskingConfig( - enabled=True, - epsilon=-0.1, # Invalid negative value - max_mask_prob=0.5, - pad_prob=0.1, - mask_token_id=103 - ) - - with pytest.raises(ValueError): - DiffusionMaskingConfig( - enabled=True, - epsilon=0.1, - max_mask_prob=1.5, # Invalid value > 1 - pad_prob=0.1, - mask_token_id=103 - ) - - with pytest.raises(ValueError): - DiffusionMaskingConfig( - enabled=True, - epsilon=0.6, # Greater than max_mask_prob - max_mask_prob=0.5, - pad_prob=0.1, - mask_token_id=103 - ) - - with pytest.raises(ValueError): - DiffusionMaskingConfig( - enabled=True, - epsilon=0.1, - max_mask_prob=0.5, - pad_prob=-0.1, # Invalid negative value - mask_token_id=103 - ) \ No newline at end of file diff --git a/tests/test_mlm_loss.py b/tests/test_mlm_loss.py deleted file mode 100644 index f7913c63..00000000 --- a/tests/test_mlm_loss.py +++ /dev/null @@ -1,149 +0,0 @@ -import pytest -import torch -import torch.nn.functional as F - -from fast_llm.layers.language_model.head import MLMHead -from fast_llm.layers.language_model.config import LanguageModelBaseConfig -from fast_llm.engine.config_utils.tensor_space import TensorSpace, DefaultDimNames, TensorDim -from fast_llm.layers.transformer.config import TransformerConfig, DiffusionMaskingConfig -from fast_llm.engine.distributed.config import DistributedConfig - - -@pytest.fixture -def mlm_config(): - transformer_config = TransformerConfig( - hidden_size=768, - num_layers=12, - num_attention_heads=12, - diffusion=DiffusionMaskingConfig( - enabled=True, - epsilon=0.1, - max_mask_prob=0.5, - pad_prob=0.1, - mask_token_id=103 - ) - ) - - return LanguageModelBaseConfig( - vocab_size=30522, - transformer=transformer_config, - 
tie_word_embeddings=False, - parallel_embeddings=False, - prediction_heads=1 - ) - - -@pytest.fixture -def tensor_space(): - distributed_config = DistributedConfig() - tensor_space = TensorSpace(distributed_config) - tensor_space.add_tensor_dim(DefaultDimNames.scalar, 1) - tensor_space.add_tensor_dim("hidden", 768) - tensor_space.add_tensor_dim("vocab", 30522) - return tensor_space - - -def test_mlm_loss_computation(mlm_config, tensor_space): - mlm_head = MLMHead(mlm_config, tensor_space, prediction_distance=0) - - batch_size = 4 - seq_len = 8 - hidden_size = 768 - vocab_size = 30522 - - hidden_states = torch.randn(batch_size, seq_len, hidden_size) - - masked_indices = torch.zeros(batch_size, seq_len, dtype=torch.bool) - masked_indices[:, [2, 5]] = True # Mask positions 2 and 5 in each sequence - - p_mask = torch.full((batch_size, seq_len), 0.15) # 15% masking probability - - labels = torch.randint(0, vocab_size, (batch_size, seq_len)) - - kwargs = { - 'masked_indices': masked_indices, - 'p_mask': p_mask, - 'labels': labels - } - - losses = {} - output = mlm_head(hidden_states, kwargs, losses) - - - assert output is not None - assert isinstance(output, torch.Tensor) - assert output.requires_grad - - - assert losses # losses dictionary should not be empty - - # Test with no masked positions - kwargs['masked_indices'] = torch.zeros_like(masked_indices) - losses = {} - output_no_masks = mlm_head(hidden_states, kwargs, losses) - assert output_no_masks is not None - - # Test with all positions masked - kwargs['masked_indices'] = torch.ones_like(masked_indices) - losses = {} - output_all_masked = mlm_head(hidden_states, kwargs, losses) - assert output_all_masked is not None - - -def test_mlm_loss_edge_cases(mlm_config, tensor_space): - mlm_head = MLMHead(mlm_config, tensor_space, prediction_distance=0) - - hidden_states = torch.randn(1, 4, 768) - masked_indices = torch.zeros(1, 4, dtype=torch.bool) - masked_indices[0, 1] = True - p_mask = torch.full((1, 4), 0.15) - labels = torch.randint(0, 30522, (1, 4)) - - kwargs = { - 'masked_indices': masked_indices, - 'p_mask': p_mask, - 'labels': labels - } - - losses = {} - output = mlm_head(hidden_states, kwargs, losses) - assert output is not None - - p_mask = torch.full((1, 4), 0.01) - kwargs['p_mask'] = p_mask - losses = {} - output = mlm_head(hidden_states, kwargs, losses) - assert output is not None - - p_mask = torch.full((1, 4), 0.5) # max_mask_prob from config - kwargs['p_mask'] = p_mask - losses = {} - output = mlm_head(hidden_states, kwargs, losses) - assert output is not None - - -def test_mlm_loss_backward(mlm_config, tensor_space): - mlm_head = MLMHead(mlm_config, tensor_space, prediction_distance=0) - - hidden_states = torch.randn(2, 6, 768, requires_grad=True) - masked_indices = torch.zeros(2, 6, dtype=torch.bool) - masked_indices[:, [1, 4]] = True - p_mask = torch.full((2, 6), 0.15) - labels = torch.randint(0, 30522, (2, 6)) - - kwargs = { - 'masked_indices': masked_indices, - 'p_mask': p_mask, - 'labels': labels - } - - losses = {} - output = mlm_head(hidden_states, kwargs, losses) - - output.backward() - - assert hidden_states.grad is not None - assert not torch.isnan(hidden_states.grad).any() - assert not torch.isinf(hidden_states.grad).any() - - assert hidden_states.grad.shape == hidden_states.shape \ No newline at end of file From 2a06ed425ec6e964e388535462d129fc5ae7aa68 Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Tue, 24 Jun 2025 02:00:08 +0000 Subject: [PATCH 18/34] tmp fix --- fast_llm/layers/transformer/rotary/config.py | 10 
+++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/fast_llm/layers/transformer/rotary/config.py b/fast_llm/layers/transformer/rotary/config.py index ce7af88d..e9c79dd4 100644 --- a/fast_llm/layers/transformer/rotary/config.py +++ b/fast_llm/layers/transformer/rotary/config.py @@ -3,7 +3,7 @@ import typing import warnings -from fast_llm.config import Field, FieldHint, config_class +from fast_llm.config import DEFAULT, Field, FieldHint, config_class from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.config import TritonConfig @@ -112,8 +112,8 @@ class YarnRotaryConfig(DefaultRotaryConfig): # TODO: Add descriptions. scale_factor: float = Field(default=8.0, hint=FieldHint.feature) - attention_factor: None | float = Field( - default=None, + attention_factor: float | None = Field( + default=DEFAULT, hint=FieldHint.feature, ) beta_fast: float = Field( @@ -128,8 +128,8 @@ class YarnRotaryConfig(DefaultRotaryConfig): def _validate(self) -> None: if self.attention_factor is None: - with self._set_implicit_default(): - self.attention_factor = 0.1 * math.log(self.scale_factor) + 1.0 + # with self._set_implicit_default(): + self.attention_factor = 0.1 * math.log(self.scale_factor) + 1.0 super()._validate() def _get_configurable_class(self) -> "type[YarnRotary]": From dd68d28e73c432ea660c03b57b83755bb067b9f4 Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Tue, 24 Jun 2025 13:58:59 +0000 Subject: [PATCH 19/34] fx tests --- fast_llm/data/data/gpt/data.py | 1 + fast_llm/engine/evaluation/evaluator.py | 2 +- fast_llm/functional/triton/cross_entropy.py | 3 +++ fast_llm/layers/transformer/rotary/config.py | 4 ++-- 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 264ec212..896a2579 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -46,6 +46,7 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling sequence_lengths = None mask_indexes = None mask_probabilities = None + masked_token_ids = None token_ids = torch.from_numpy(stacked_ids) diff --git a/fast_llm/engine/evaluation/evaluator.py b/fast_llm/engine/evaluation/evaluator.py index f07a8c48..8f963696 100644 --- a/fast_llm/engine/evaluation/evaluator.py +++ b/fast_llm/engine/evaluation/evaluator.py @@ -119,7 +119,7 @@ def setup( phase=PhaseType.validation, ) - self._loss_defs = self._multi_stage.base_model.loss_defs + self._loss_defs = self._multi_stage.base_model.get_loss_defs() self._evaluation_iterator = None self._is_setup = True diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 8cb59c85..c257edbb 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -125,6 +125,7 @@ def triton_cross_entropy_forward_backward( grad_output: float | None, logits_scale_factor: float, target_format: TargetFormat, + loss_weight: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ A fast triton implementation of cross-entropy, which combines the casting and forward and backward passes, @@ -133,6 +134,8 @@ def triton_cross_entropy_forward_backward( TODO: Better handling of `grad_output = None` """ assert TritonConfig.TRITON_ENABLED + assert loss_weight is None, "Loss weight not supported in triton cross-entropy implementation." + # TODO: Improve assumptions. 
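Because the Triton kernel above now rejects `loss_weight`, callers that need per-token weighting have to stay on the fused implementation. The guard below only illustrates that dispatch decision; it is not the repository's actual routing logic and the function name is hypothetical:

```python
import torch

def pick_cross_entropy_impl(requested: str, loss_weight: torch.Tensor | None) -> str:
    # The triton kernel asserts `loss_weight is None`, so fall back to the
    # fused implementation whenever a weight tensor is supplied.
    if requested == "triton" and loss_weight is not None:
        return "fused"
    return requested

print(pick_cross_entropy_impl("triton", torch.ones(4)))  # fused
print(pick_cross_entropy_impl("triton", None))           # triton
```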
assert logits.is_contiguous() assert target.is_contiguous() diff --git a/fast_llm/layers/transformer/rotary/config.py b/fast_llm/layers/transformer/rotary/config.py index e9c79dd4..033cbb35 100644 --- a/fast_llm/layers/transformer/rotary/config.py +++ b/fast_llm/layers/transformer/rotary/config.py @@ -3,7 +3,7 @@ import typing import warnings -from fast_llm.config import DEFAULT, Field, FieldHint, config_class +from fast_llm.config import Field, FieldHint, config_class from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.config import TritonConfig @@ -113,7 +113,7 @@ class YarnRotaryConfig(DefaultRotaryConfig): # TODO: Add descriptions. scale_factor: float = Field(default=8.0, hint=FieldHint.feature) attention_factor: float | None = Field( - default=DEFAULT, + default=None, hint=FieldHint.feature, ) beta_fast: float = Field( From e0a7c8081d781782b96793f1665d30ccca20fca9 Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Wed, 25 Jun 2025 14:41:09 +0000 Subject: [PATCH 20/34] update missing rotery export --- fast_llm/models/gpt/conversion.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 7e4f4e11..2c432cbc 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -402,6 +402,7 @@ def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing elif type(rotary_config) is YarnRotaryConfig: rotary_scaling = { "rope_type": "yarn", + "factor": rotary_config.scale_factor, "attention_factor": rotary_config.attention_factor, "beta_fast": rotary_config.beta_fast, "beta_slow": rotary_config.beta_slow, @@ -433,6 +434,7 @@ def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.A elif rotary_type == "yarn": rotary_config.update( { + "scale_factor": rope_scaling.get("factor", DEFAULT), "attention_factor": rope_scaling.get("attention_factor", DEFAULT), "beta_fast": rope_scaling.get("beta_fast", DEFAULT), "beta_slow": rope_scaling.get("beta_slow", DEFAULT), From 0306e36227801a7cee2711a0109c22dc55afb5ba Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Wed, 25 Jun 2025 16:17:59 +0000 Subject: [PATCH 21/34] reset attention_factor to old behaviour --- fast_llm/layers/transformer/rotary/config.py | 6 +++--- fast_llm/layers/transformer/rotary/rotary.py | 5 ++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/fast_llm/layers/transformer/rotary/config.py b/fast_llm/layers/transformer/rotary/config.py index 033cbb35..8f2c9ab8 100644 --- a/fast_llm/layers/transformer/rotary/config.py +++ b/fast_llm/layers/transformer/rotary/config.py @@ -127,9 +127,9 @@ class YarnRotaryConfig(DefaultRotaryConfig): original_context_length: int = Field(default=8192, hint=FieldHint.feature) def _validate(self) -> None: - if self.attention_factor is None: - # with self._set_implicit_default(): - self.attention_factor = 0.1 * math.log(self.scale_factor) + 1.0 + # if self.attention_factor is None: + # # with self._set_implicit_default(): + # self.attention_factor = 0.1 * math.log(self.scale_factor) + 1.0 super()._validate() def _get_configurable_class(self) -> "type[YarnRotary]": diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index 056b9aa4..867e1e33 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/transformer/rotary/rotary.py @@ -181,7 +181,10 @@ class YarnRotary[ConfigType: 
YarnRotaryConfig](DefaultRotary[YarnRotaryConfig]): """ def _get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda") -> torch.Tensor: - return super()._get_frequencies(sequence_length, kv_channels, device) * self._config.attention_factor + attention_factor = self._config.attention_factor + if attention_factor is None: + attention_factor = 0.1 * math.log(self._config.scale_factor) + 1.0 + return super()._get_frequencies(sequence_length, kv_channels, device) * attention_factor def _get_angle_scales(self, kv_channels: int, device="cuda") -> torch.Tensor: scales = super()._get_angle_scales(kv_channels, device) From 6bcb38d6d66c5ba9f4356e5999f57f40c343d4c1 Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Fri, 27 Jun 2025 14:58:07 +0000 Subject: [PATCH 22/34] setting attention to _flash_attn_func --- fast_llm/layers/transformer/attention.py | 4 ++-- fast_llm/layers/transformer/config.py | 1 + fast_llm/models/gpt/model.py | 28 ++++++++++-------------- 3 files changed, 14 insertions(+), 19 deletions(-) diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 3351c990..ac42ecf8 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -371,7 +371,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ max_seqlen_k=kwargs.get(TransformerKwargs.max_seqlen_k), dropout_p=self._config.attention_dropout if self.training else 0.0, window_size=(-1, -1) if window_size is None else (window_size - 1, 0), - causal=True, + causal=kwargs.get(TransformerKwargs.causal, True), softmax_scale=self._softmax_scale, ).view(*out_dims) else: @@ -381,7 +381,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ value, window_size=(-1, -1) if window_size is None else (window_size - 1, 0), dropout_p=self._config.attention_dropout if self.training else 0.0, - causal=True, + causal=kwargs.get(TransformerKwargs.causal, True), softmax_scale=self._softmax_scale, ) input_ = input_.flatten(-2) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index fac5a2ff..f38b6dba 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -82,6 +82,7 @@ class TransformerKwargs: sequence_length = "sequence_length" # TODO: Move grad_output = "grad_output" + causal = "causal" class TransformerLossNames: diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index bea192bc..777341cb 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -55,13 +55,15 @@ def __init__( # We have multiple identical rotary modules/preprocessors, so it's simpler to make a new one here. # TODO: Find a better solution. self._preprocessors.append(self._config.transformer.rotary.build(self._tensor_space)) - if self._use_flash_attention: - self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space)) - else: - self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) - if self._config.enable_dpo: # TODO better way to pass in? 
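The rotary change above moves the YaRN attention-factor default out of `_validate` and into the module. The fallback formula can be checked on its own (example values only):

```python
import math

def yarn_attention_factor(attention_factor: float | None, scale_factor: float) -> float:
    # Same fallback as YarnRotary._get_frequencies: derive the factor from the
    # scale factor when no explicit value is configured.
    if attention_factor is None:
        return 0.1 * math.log(scale_factor) + 1.0
    return attention_factor

print(yarn_attention_factor(None, 8.0))  # ~1.208
print(yarn_attention_factor(0.5, 8.0))   # 0.5, an explicit value wins
```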
- self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._tensor_space)) + if not self._config.transformer.diffusion: + if self._use_flash_attention: + self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space)) + else: + self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) + + if self._config.enable_dpo: # TODO better way to pass in? + self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._tensor_space)) def get_output_layers(self) -> list[Layer]: layers = [] @@ -341,17 +343,9 @@ def preprocess( kwargs[LanguageModelKwargs.mask_probabilities] = batch.mask_probabilities.to( device=self._tensor_space.distributed.device ) - # Setup bidirection attention for diffusion should we set this in a preprocessor? BackupAttentionPreprocessor? - batch_size, seq_len = batch.token_ids.shape - attention_mask = torch.ones( - (batch_size, 1, seq_len, seq_len), - dtype=torch.bool, - device=self._tensor_space.distributed.device, - ) - kwargs[TransformerKwargs.attention_mask] = attention_mask - kwargs[TransformerKwargs.attention_mask_value] = torch.tensor( - -10000.0, device=self._tensor_space.distributed.device - ) + # Setup bidirection attention for masked diffusion + # It uses _flash_attn_func so no need to set attention_mask and attention_mask_value. + kwargs[TransformerKwargs.causal] = False # set token ids to masked tokens batch.token_ids = batch.masked_token_ids From 093aa33ae20d0320e766234ea394b65be2ee2c21 Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Sat, 28 Jun 2025 01:37:03 +0000 Subject: [PATCH 23/34] debug --- fast_llm/layers/language_model/head.py | 2 ++ fast_llm/models/gpt/model.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 15af5822..29476597 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -427,6 +427,8 @@ def _logits_cross_entropy_forward_backward( dim=1, ).to(logits.dtype) + print(f"Loss weight: {loss_weight}") + loss, grad = cross_entropy_forward_backward( logits=logits.flatten(0, -2), target=target, diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 777341cb..272c79f9 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -346,6 +346,8 @@ def preprocess( # Setup bidirection attention for masked diffusion # It uses _flash_attn_func so no need to set attention_mask and attention_mask_value. kwargs[TransformerKwargs.causal] = False + + print(f"labels shape: {labels}, tokens: {batch.token_ids}, mask indexes shape: {batch.mask_indexes}") # set token ids to masked tokens batch.token_ids = batch.masked_token_ids From 141ed8810f0f9c31975302381f8bf91b3ac73e4c Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Sat, 28 Jun 2025 14:23:47 +0000 Subject: [PATCH 24/34] avg only non-zero loss --- fast_llm/functional/cross_entropy.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 7fbf285c..4fe1a5f3 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -163,7 +163,8 @@ def _fused_cross_entropy_forward_backward( loss_weight_expanded = loss_weight.reshape(-1, 1) grad = grad * loss_weight_expanded if grad is not None else None # Avg across all the tokens. 
- return per_sample_loss.mean(), grad + denom = torch.clamp((loss_weight != 0).sum(), min=1) + return per_sample_loss.sum() / denom, grad _CROSS_ENTROPY_IMPLEMENTATIONS = { From 8bb00ed7fefa42e4285fa839ffd2a33f1320ac07 Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Sat, 28 Jun 2025 14:24:47 +0000 Subject: [PATCH 25/34] debug remove --- fast_llm/layers/language_model/head.py | 2 +- fast_llm/models/gpt/model.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 29476597..483483c7 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -427,7 +427,7 @@ def _logits_cross_entropy_forward_backward( dim=1, ).to(logits.dtype) - print(f"Loss weight: {loss_weight}") + # print(f"Loss weight: {loss_weight}") loss, grad = cross_entropy_forward_backward( logits=logits.flatten(0, -2), diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 272c79f9..5403f41d 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -346,8 +346,8 @@ def preprocess( # Setup bidirection attention for masked diffusion # It uses _flash_attn_func so no need to set attention_mask and attention_mask_value. kwargs[TransformerKwargs.causal] = False - - print(f"labels shape: {labels}, tokens: {batch.token_ids}, mask indexes shape: {batch.mask_indexes}") + + # print(f"labels shape: {labels}, tokens: {batch.token_ids}, mask indexes shape: {batch.mask_indexes}") # set token ids to masked tokens batch.token_ids = batch.masked_token_ids From 38737d4530e33e17f723e7ac729edf6b5b268b82 Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Fri, 4 Jul 2025 01:38:44 +0000 Subject: [PATCH 26/34] remove non-zero weight --- fast_llm/functional/cross_entropy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 4fe1a5f3..00c66937 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -163,8 +163,8 @@ def _fused_cross_entropy_forward_backward( loss_weight_expanded = loss_weight.reshape(-1, 1) grad = grad * loss_weight_expanded if grad is not None else None # Avg across all the tokens. - denom = torch.clamp((loss_weight != 0).sum(), min=1) - return per_sample_loss.sum() / denom, grad + # denom = torch.clamp((loss_weight != 0).sum(), min=1) + return per_sample_loss.mean(), grad _CROSS_ENTROPY_IMPLEMENTATIONS = { From b043efeca6549680b5802c2906cbde58048da739 Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Fri, 4 Jul 2025 16:13:57 +0000 Subject: [PATCH 27/34] revert to mean loss on all tokens --- fast_llm/functional/cross_entropy.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 00c66937..ac6a36f0 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -162,8 +162,6 @@ def _fused_cross_entropy_forward_backward( per_sample_loss = per_sample_loss * loss_weight.flatten() loss_weight_expanded = loss_weight.reshape(-1, 1) grad = grad * loss_weight_expanded if grad is not None else None - # Avg across all the tokens. 
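The back-and-forth in these commits is about the denominator: averaging the weighted loss over every token versus only over tokens whose weight is non-zero. A small numeric comparison of the two choices (illustrative numbers):

```python
import torch

per_token_loss = torch.tensor([2.0, 4.0, 6.0, 8.0])
loss_weight = torch.tensor([0.0, 1.0, 0.0, 1.0])
weighted = per_token_loss * loss_weight

# Current behaviour: divide by the total number of tokens.
mean_all = weighted.mean()                             # (0 + 4 + 0 + 8) / 4 = 3.0

# Earlier experiment: divide only by the number of weighted tokens.
denom = torch.clamp((loss_weight != 0).sum(), min=1)
mean_nonzero = weighted.sum() / denom                  # (4 + 8) / 2 = 6.0

print(mean_all.item(), mean_nonzero.item())
```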
- # denom = torch.clamp((loss_weight != 0).sum(), min=1) return per_sample_loss.mean(), grad From 0c221fd8318409d51dd5f18aec2d2cc12e96dde4 Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Fri, 4 Jul 2025 16:27:14 +0000 Subject: [PATCH 28/34] tmp --- fast_llm/models/gpt/model.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 5403f41d..6bf8bdbd 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -346,6 +346,22 @@ def preprocess( # Setup bidirection attention for masked diffusion # It uses _flash_attn_func so no need to set attention_mask and attention_mask_value. kwargs[TransformerKwargs.causal] = False + # batch_size, seq_len = batch.token_ids.shape + # attention_mask = torch.ones( + # (batch_size, 1, seq_len, seq_len), + # dtype=torch.bool, + # device=self._tensor_space.distributed.device, + # ) + # kwargs[TransformerKwargs.attention_mask] = attention_mask + # # kwargs[TransformerKwargs.attention_mask_value] = torch.tensor( + # # -10000.0, device=self._tensor_space.distributed.device + # # ) + # kwargs[TransformerKwargs.attention_mask_value] = torch.full( + # [], + # torch.finfo(self._tensor_space.distributed_config.training_dtype.torch).min, + # dtype=self._tensor_space.distributed_config.training_dtype.torch, + # device=self._tensor_space.distributed.device, + # ) # print(f"labels shape: {labels}, tokens: {batch.token_ids}, mask indexes shape: {batch.mask_indexes}") From d29af35b6d3aeceeb60ea66a2ef592199d2a6b72 Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Fri, 4 Jul 2025 18:38:52 +0000 Subject: [PATCH 29/34] adding fused attn --- fast_llm/layers/transformer/attention.py | 4 +++ fast_llm/models/gpt/model.py | 31 ++++++++++++------------ 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index ac42ecf8..78c38b6d 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -182,7 +182,11 @@ def _attn_fused( ).view(b, self._local_head_groups, sq, self._local_heads_per_group, sk) attn_weights = attn_weights.to(torch.float32) * self._layer_index + + attn_weights = attn_weights.transpose(2, 3) attn_weights = torch.where(mask, attn_weights, mask_value) + attn_weights = attn_weights.transpose(2, 3) + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) with set_generator(self._tensor_space.distributed.tp_generator): diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 6bf8bdbd..6d420be6 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -345,24 +345,25 @@ def preprocess( ) # Setup bidirection attention for masked diffusion # It uses _flash_attn_func so no need to set attention_mask and attention_mask_value. 
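The fused attention path shown above masks scores with `torch.where(mask, attn_weights, mask_value)` before the softmax. A standalone sketch of that masking step, with toy shapes and the same finfo-min fill value:

```python
import torch

scores = torch.randn(1, 1, 4, 4)                 # (batch, heads, query, key) scores
mask = torch.ones(1, 1, 4, 4, dtype=torch.bool)  # True = attention allowed
mask[..., 0, 3] = False                          # forbid one query/key pair

mask_value = torch.full([], torch.finfo(scores.dtype).min, dtype=scores.dtype)

# Disallowed positions get the most negative representable value, so softmax
# assigns them (effectively) zero probability.
probs = torch.softmax(torch.where(mask, scores, mask_value), dim=-1)
print(probs[0, 0, 0])  # last entry is ~0
```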
- kwargs[TransformerKwargs.causal] = False - # batch_size, seq_len = batch.token_ids.shape - # attention_mask = torch.ones( - # (batch_size, 1, seq_len, seq_len), - # dtype=torch.bool, - # device=self._tensor_space.distributed.device, - # ) - # kwargs[TransformerKwargs.attention_mask] = attention_mask + # kwargs[TransformerKwargs.causal] = False + batch_size, seq_len = batch.token_ids.shape + seq_len -= 1 # last token is dropped inputs + attention_mask = torch.ones( + (batch_size, 1, seq_len, seq_len), + dtype=torch.bool, + device=self._tensor_space.distributed.device, + ) + kwargs[TransformerKwargs.attention_mask] = attention_mask.unsqueeze(1).unsqueeze(1) # # kwargs[TransformerKwargs.attention_mask_value] = torch.tensor( # # -10000.0, device=self._tensor_space.distributed.device # # ) - # kwargs[TransformerKwargs.attention_mask_value] = torch.full( - # [], - # torch.finfo(self._tensor_space.distributed_config.training_dtype.torch).min, - # dtype=self._tensor_space.distributed_config.training_dtype.torch, - # device=self._tensor_space.distributed.device, - # ) - + kwargs[TransformerKwargs.attention_mask_value] = torch.full( + [], + torch.finfo(self._tensor_space.distributed_config.training_dtype.torch).min, + dtype=self._tensor_space.distributed_config.training_dtype.torch, + device=self._tensor_space.distributed.device, + ) + # print(f"attention_mask : {attention_mask}") # print(f"labels shape: {labels}, tokens: {batch.token_ids}, mask indexes shape: {batch.mask_indexes}") # set token ids to masked tokens From aa0d08cef499678f91108b6645c0be3e3ca9d560 Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Sun, 6 Jul 2025 02:36:29 +0000 Subject: [PATCH 30/34] include ar+masking --- fast_llm/config.py | 10 ++ fast_llm/data/data/gpt/data.py | 138 ++++++++++++++++++++++- fast_llm/data/dataset/gpt/config.py | 28 +++-- fast_llm/layers/language_model/head.py | 5 +- fast_llm/layers/transformer/config.py | 6 +- fast_llm/models/gpt/model.py | 147 +++++++++++++++++++------ 6 files changed, 287 insertions(+), 47 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index cdc1dd5d..eb4057a1 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -1098,3 +1098,13 @@ def pop_nested_dict_value[ return d.pop(keys[-1]) else: return d.pop(keys) + + +class DiffusionStyle(str, enum.Enum): + """ + Type of diffusion masking to use. + """ + + masked = "masked" + ar_masked = "autoregressive_masked" + none = None diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 896a2579..f5472fa1 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -9,6 +9,7 @@ import torch import torch.utils.data +from fast_llm.config import DiffusionStyle from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.abstract import Data from fast_llm.data.data.gpt.config import GPTDataConfig @@ -37,6 +38,89 @@ class GPTBatch: mask_indexes: torch.Tensor | None = None mask_probabilities: torch.Tensor | None = None masked_token_ids: torch.Tensor | None = None + loss_weights: torch.Tensor | None = None + in_context_length: torch.Tensor | None = None + in_context: torch.Tensor | None = None + + +def do_mask(x, mask, mask_token_id): + x = x.clone() + x[mask] = mask_token_id + return x + + +def do_uniform(x, is_uniform, vocab_size): + # WARNING! 
"Shuffle" was really meant to mean "uniformly sample among all non-mask tokens" + x = x.clone() + uniform = torch.randint(0, vocab_size, x.size()) + x[is_uniform] = uniform[is_uniform] + return x + + +def prepare_batch( + data_ids, + positions, + padded, + mask_token_id, + vocab_size, + context_length, + p_mask, + *, + p_uniform=0.0, + ar_factor=1.0, + un_factor=1.0, + last_factor=0.0, + in_mask=None, + in_uniform=None, +): + + B, L = positions.size() + context_length = context_length.unsqueeze(1).expand(B, L) + p_mask = p_mask.unsqueeze(1) + + # Reminder: a context_length of zero still has one in_context token (à la ) + in_context = positions <= context_length + if in_mask is None: + in_mask = (~in_context) & (torch.rand(B, L) < p_mask) + + if in_uniform is None: + in_uniform = (~in_context) & (~in_mask) & (torch.rand(B, L) < p_uniform) + in_clean = (~in_context) & (~in_mask) & (~in_uniform) + + loss_weights = (~padded)[:, 1:] * torch.cat( + [ + ar_factor * in_context[:, 1:] + + in_mask[:, 1:] / p_mask + + un_factor * ((1 - p_uniform) * in_uniform[:, 1:] + p_uniform * in_clean[:, 1:]) / (1 - p_mask), + last_factor * torch.ones(B, 1), + ], + dim=1, + ) + + input_ids = do_uniform(do_mask(data_ids[:, :-1], in_mask, mask_token_id), in_uniform, vocab_size) + + # print( + # f"{'Name':<20} {'Shape/Value':<30}\n" + # f"{'-'*50}\n" + # f"{'input_ids':<20} {str(input_ids.shape):<30}\n" + # f"{'in_context':<20} {str(in_context.shape):<30}\n" + # f"{'in_mask':<20} {str(in_mask.shape):<30}\n" + # f"{'in_uniform':<20} {str(in_uniform.shape):<30}\n" + # f"{'in_clean':<20} {str(in_clean.shape):<30}\n" + # f"{'loss_weights':<20} {str(loss_weights.shape):<30}\n" + # f"{'in_context_length':<20} {str(context_length):<30}\n" + # ) + + return { + "in_context": in_context, # Only for tokens to be predicted + "in_mask": in_mask, + "in_uniform": in_uniform, + "in_clean": in_clean, + "input_ids": input_ids, + # "target_ids": data_ids, + "loss_weights": loss_weights, + # "in_context_length": context_length, + } def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch: @@ -48,9 +132,13 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling mask_probabilities = None masked_token_ids = None + loss_weights = None + in_context_length = None + in_context = None + token_ids = torch.from_numpy(stacked_ids) - if sampling_parameters.diffusion.enabled: + if sampling_parameters.diffusion.style == DiffusionStyle.masked: diffusion_config = sampling_parameters.diffusion @@ -90,6 +178,51 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling mask_indexes = mask_indexes[:, :-1] # Remove the last token, which is not used for prediction. 
mask_probabilities = p_mask + elif sampling_parameters.diffusion.style == DiffusionStyle.ar_masked: + diffusion_config = sampling_parameters.diffusion + batch_size, seq_len = token_ids.shape + data_ids = token_ids + padded = torch.zeros_like(data_ids, dtype=torch.bool) + positions = torch.arange(seq_len - 1).unsqueeze(0).expand(batch_size, seq_len - 1) + + # TODO: + # 90% of the batch: C = random [0, seq_len // 4], 10%: C = random in [0, seq_len-2) + prob = torch.rand(1) + C = torch.where( + prob > diffusion_config.context_sampler, + torch.randint(0, seq_len // 4, (batch_size,), dtype=torch.long), + torch.randint(0, seq_len - 2, (batch_size,), dtype=torch.long), + ) + # C = torch.randint(0, (seq_len - 2), (batch_size,), dtype=torch.long) + # C = -torch.ones(batch_size, dtype=torch.int) + # Generate a random tensor of batch size to seed masking probabilities + t = torch.rand((batch_size,)) + # Compute the mask probabilities for every sequence in the batch leaving extrams 0 & 1 + p_mask = (1 - (2 * diffusion_config.epsilon)) * t + diffusion_config.epsilon + + batch_data = prepare_batch( + data_ids=data_ids, + positions=positions, + padded=padded, + mask_token_id=diffusion_config.mask_token_id, + vocab_size=sampling_parameters.vocab_size, + context_length=C, + p_mask=p_mask, + p_uniform=0.0, # no uniform shuffling of tokens + ar_factor=diffusion_config.ar_factor, + un_factor=1.0, + last_factor=0.0, + ) + + # token_ids = batch_data["input_ids"] + masked_token_ids = batch_data["input_ids"] + + mask_indexes = batch_data["in_mask"] + # mask_probabilities = torch.full_like(mask_indexes, diffusion_config.max_mask_prob, dtype=token_ids.dtype) + loss_weights = batch_data["loss_weights"] + in_context_length = C + in_context = batch_data["in_context"] + if sampling_parameters.use_loss_masking_spans: stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] @@ -112,6 +245,9 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling mask_indexes=mask_indexes, mask_probabilities=mask_probabilities, masked_token_ids=masked_token_ids, + loss_weights=loss_weights, + in_context_length=in_context_length, + in_context=in_context, ) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 4b78ae34..180281fd 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -8,7 +8,16 @@ import yaml -from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none +from fast_llm.config import ( + Config, + DiffusionStyle, + Field, + FieldHint, + FieldUpdate, + check_field, + config_class, + skip_valid_if_none, +) from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.dataset.config import ( BlendedDatasetConfig, @@ -48,8 +57,8 @@ class ShufflingType(str, enum.Enum): class DiffusionMaskingConfig(Config): """Configuration for diffusion-based masking during data preparation.""" - enabled: bool = Field( - default=False, desc="Whether to use masked diffusion during training", hint=FieldHint.feature + style: DiffusionStyle = Field( + default=DiffusionStyle.none, desc="Whether to use masked diffusion during training", hint=FieldHint.feature ) epsilon: float = Field( @@ -68,14 +77,17 @@ class DiffusionMaskingConfig(Config): ) mask_token_id: int = Field(default=103, desc="Token ID to use for masking", hint=FieldHint.optional) + ar_factor: float = Field( + default=1.0, + desc="Factor for the AR weigting on overal loss.", + 
hint=FieldHint.optional, + ) + context_sampler: float = Field( + default=1.0, desc="Context lenght C sampled in under 25% sequence length vs all", hint=FieldHint.optional + ) def _validate(self) -> None: super()._validate() - Assert.lt(self.epsilon, self.max_mask_prob) # , "epsilon must be less than max_mask_prob") - Assert.lt( - self.max_mask_prob, - 1.0, - ) @config_class() diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 483483c7..c47795b8 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -4,7 +4,7 @@ from torch._C._distributed_c10d import ReduceOp # noqa from torch.distributed import all_reduce -from fast_llm.config import Configurable +from fast_llm.config import Configurable, DiffusionStyle from fast_llm.core.ops import split_op from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace @@ -377,7 +377,8 @@ def __init__( prediction_distance: int, ): super().__init__(config, tensor_space, prediction_distance) - self._loss_name = LanguageModelLossNames.mlm_loss + if config.transformer.diffusion is not None and config.transformer.diffusion == DiffusionStyle.masked: + self._loss_name = LanguageModelLossNames.mlm_loss def _logits_cross_entropy_forward_backward( self, diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index f38b6dba..7e0155f6 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -6,7 +6,7 @@ import typing import warnings -from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import DiffusionStyle, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames @@ -486,8 +486,8 @@ class TransformerConfig(LLMBlockConfig): " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.", hint=FieldHint.expert, ) - diffusion: bool = Field( - default=False, + diffusion: DiffusionStyle = Field( + default=DiffusionStyle.none, desc="Use masked-diffusion for training.", hint=FieldHint.feature, ) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 6d420be6..21fd84d6 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -3,6 +3,7 @@ import torch +from fast_llm.config import DiffusionStyle from fast_llm.data.data.gpt.data import GPTBatch from fast_llm.engine.base_model.base_model import BaseModel, Layer, LossDef from fast_llm.engine.base_model.config import Preprocessor @@ -335,39 +336,119 @@ def preprocess( labels = torch.where(loss_mask, labels, -100) kwargs[LanguageModelKwargs.labels] = labels - if batch.mask_indexes is not None: - # We are in masked-diffusion mode, so we need to add the mask indexes and probabilities to kwargs - kwargs[LanguageModelKwargs.mask_indexes] = batch.mask_indexes.to( - device=self._tensor_space.distributed.device - ) - kwargs[LanguageModelKwargs.mask_probabilities] = batch.mask_probabilities.to( - device=self._tensor_space.distributed.device - ) - # Setup bidirection attention for masked diffusion - # It uses _flash_attn_func so no need to set attention_mask and attention_mask_value. 
- # kwargs[TransformerKwargs.causal] = False - batch_size, seq_len = batch.token_ids.shape - seq_len -= 1 # last token is dropped inputs - attention_mask = torch.ones( - (batch_size, 1, seq_len, seq_len), - dtype=torch.bool, - device=self._tensor_space.distributed.device, - ) - kwargs[TransformerKwargs.attention_mask] = attention_mask.unsqueeze(1).unsqueeze(1) - # # kwargs[TransformerKwargs.attention_mask_value] = torch.tensor( - # # -10000.0, device=self._tensor_space.distributed.device - # # ) - kwargs[TransformerKwargs.attention_mask_value] = torch.full( - [], - torch.finfo(self._tensor_space.distributed_config.training_dtype.torch).min, - dtype=self._tensor_space.distributed_config.training_dtype.torch, - device=self._tensor_space.distributed.device, - ) - # print(f"attention_mask : {attention_mask}") - # print(f"labels shape: {labels}, tokens: {batch.token_ids}, mask indexes shape: {batch.mask_indexes}") - - # set token ids to masked tokens - batch.token_ids = batch.masked_token_ids + if self._config.transformer.diffusion is not None: + # if batch.mask_indexes is not None: + if self._config.transformer.diffusion == DiffusionStyle.masked: + assert batch.mask_indexes is not None, "masked-diffusion mode needs to set mask indexes" + assert batch.mask_probabilities is not None, "masked-diffusion mode needs to set mask" + + # We are in masked-diffusion mode, so we need to add the mask indexes and probabilities to kwargs + kwargs[LanguageModelKwargs.mask_indexes] = batch.mask_indexes.to( + device=self._tensor_space.distributed.device + ) + kwargs[LanguageModelKwargs.mask_probabilities] = batch.mask_probabilities.to( + device=self._tensor_space.distributed.device + ) + # Setup bidirection attention for masked diffusion + # It uses _flash_attn_func so no need to set attention_mask and attention_mask_value. + # kwargs[TransformerKwargs.causal] = False + batch_size, seq_len = batch.token_ids.shape + seq_len -= 1 # last token is dropped inputs + attention_mask = torch.ones( + (batch_size, 1, seq_len, seq_len), + dtype=torch.bool, + device=self._tensor_space.distributed.device, + ) + kwargs[TransformerKwargs.attention_mask] = attention_mask.unsqueeze(1).unsqueeze(1) + # # kwargs[TransformerKwargs.attention_mask_value] = torch.tensor( + # # -10000.0, device=self._tensor_space.distributed.device + # # ) + kwargs[TransformerKwargs.attention_mask_value] = torch.full( + [], + torch.finfo(self._tensor_space.distributed_config.training_dtype.torch).min, + dtype=self._tensor_space.distributed_config.training_dtype.torch, + device=self._tensor_space.distributed.device, + ) + # print(f"attention_mask : {attention_mask}") + # print(f"labels shape: {labels}, tokens: {batch.token_ids}, mask indexes shape: {batch.mask_indexes}") + + # set token ids to masked tokens + batch.token_ids = batch.masked_token_ids + + elif self._config.transformer.diffusion == DiffusionStyle.ar_masked: + + # We are in masked-diffusion mode, so we need to add the mask indexes and probabilities to kwargs + kwargs[LanguageModelKwargs.mask_indexes] = batch.mask_indexes.to( + device=self._tensor_space.distributed.device + ) + + kwargs[LanguageModelKwargs.loss_weights] = batch.loss_weights.to( + device=self._tensor_space.distributed.device + ) + + kwargs[LanguageModelKwargs.in_context] = batch.in_context.to( + device=self._tensor_space.distributed.device + ) + + # Setup bidirection attention for diffusion should we set this in a preprocessor? BackupAttentionPreprocessor? 
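To make the `loss_weights` consumed by this branch concrete, here is a stripped-down, single-sequence version of the weighting that `prepare_batch` computes earlier in the patch. It keeps only the terms this configuration uses (`p_uniform=0`, `last_factor=0`); every number is invented:

```python
import torch

B, L = 1, 6
positions = torch.arange(L).unsqueeze(0)
context_length = torch.tensor([2])        # positions 0..2 are clean context
p_mask = torch.tensor([0.5])              # masking rate drawn for this sequence
ar_factor, last_factor = 1.0, 0.0

in_context = positions <= context_length.unsqueeze(1)
in_mask = (~in_context) & (torch.rand(B, L) < p_mask.unsqueeze(1))

# Context targets get the plain autoregressive weight, masked targets get
# 1/p_mask, and the final (label-less) position gets last_factor.
loss_weights = torch.cat(
    [
        ar_factor * in_context[:, 1:].float() + in_mask[:, 1:].float() / p_mask.unsqueeze(1),
        last_factor * torch.ones(B, 1),
    ],
    dim=1,
)
print(in_context, in_mask, loss_weights, sep="\n")
```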
+ # batch_size, seq_len = batch.token_ids.shape + # seq_len -= 1 # last token is drop from the input + # # Compute attention mask for diffusion + C = batch.in_context_length.to(device=self._tensor_space.distributed.device) + # row_idx = torch.arange(seq_len, device=self._tensor_space.distributed.device).view(1, seq_len, 1) + # col_idx = torch.arange(seq_len, device=self._tensor_space.distributed.device).view(1, 1, seq_len) + # C_exp = C.view(batch_size, 1, 1) + + # causal_mask = col_idx <= row_idx + # row_idx < C_exp + # col_idx < C_exp + + # attn_mask = torch.zeros( + # batch_size, seq_len, seq_len, dtype=torch.bool, device=self._tensor_space.distributed.device + # ) + + # for b in range(batch_size): + # C_val = C[b].item() + + # if C_val > 0: + # context_causal = causal_mask[0, :C_val, :C_val] + # attn_mask[b, :C_val, :C_val] = context_causal + + # if C_val > 0 and C_val < seq_len: + # attn_mask[b, C_val:, :C_val] = True + + # if C_val < seq_len: + # attn_mask[b, C_val:, C_val:] = True + + # Handle padding if needed + # if batch.sequence_lengths is not None: + # padded = torch.zeros( + # batch_size, seq_len, dtype=torch.bool, device=self._tensor_space.distributed.device + # ) + # for b in range(batch_size): + # padded[b, batch.sequence_lengths[b] :] = True + # not_padded = ~padded[:, 1:] + # attn_mask = attn_mask & not_padded.unsqueeze(1) & not_padded.unsqueeze(2) + + # Reshape to match expected attention mask format + attention_mask = attn_mask.unsqueeze(1).unsqueeze(1) # Add additional dimension + # print(f"attention_mask shape: {attention_mask.shape}\n{attention_mask}") + kwargs[TransformerKwargs.attention_mask] = attention_mask + kwargs[TransformerKwargs.attention_mask_value] = torch.full( + [], + torch.finfo(self._tensor_space.distributed_config.training_dtype.torch).min, + dtype=self._tensor_space.distributed_config.training_dtype.torch, + device=self._tensor_space.distributed.device, + ) + batch.token_ids = batch.masked_token_ids + # print(f"C: {C}") + # print(f"masked_token_ids: {batch.masked_token_ids}") + # print(f"token_ids: {batch.token_ids}") + # print(f"labels: {labels}") + # print(f"loss_weights: {batch.loss_weights}") + # print(f"mask indexes: {batch.mask_indexes}") + # print(f"in_context: {batch.in_context}") + # print(f"attention_mask: {attention_mask}") kwargs.update(reference_logits[i]) From 632dc7c608c1d3e8f13f0f7f8d3bde45834a41f7 Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Sun, 6 Jul 2025 05:16:59 +0000 Subject: [PATCH 31/34] main update cr loss --- fast_llm/functional/cross_entropy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index ac6a36f0..ebbd21ad 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -147,7 +147,7 @@ def _fused_cross_entropy_forward_backward( else: predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True) - per_sample_loss = sum_exp_logits.log().sub(logits_norm.gather(1, target).squeeze(1)) * loss_mask + per_sample_loss = sum_exp_logits.log() - predicted_logits if loss_weight is None: if loss_mask is not None: From 0b469fb98ff4fb42582f9f340ec4d66a54909972 Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Tue, 8 Jul 2025 00:42:41 +0000 Subject: [PATCH 32/34] include ar+diff option as a seperate style --- fast_llm/data/data/gpt/data.py | 1 - fast_llm/layers/language_model/config.py | 2 + fast_llm/layers/language_model/head.py | 97 +++++++++++++++--------- fast_llm/layers/transformer/attention.py 
| 26 +++++++ fast_llm/models/gpt/model.py | 63 ++++++++------- 5 files changed, 124 insertions(+), 65 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 076de9f5..d023ea8e 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -139,7 +139,6 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling token_ids = torch.from_numpy(stacked_ids) if sampling_parameters.diffusion.style == DiffusionStyle.masked: - diffusion_config = sampling_parameters.diffusion batch_size, seq_len = token_ids.shape diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 0a6f65dc..e0d22e54 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -42,6 +42,8 @@ class LanguageModelKwargs: mask_indexes = "mask_indexes" mask_probabilities = "mask_probabilities" mask_inputs = "mask_inputs" + loss_weights = "loss_weights" + in_context = "in_context" @config_class() diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 76467360..b8285b8c 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -377,7 +377,7 @@ def __init__( prediction_distance: int, ): super().__init__(config, tensor_space, prediction_distance) - if config.transformer.diffusion is not None and config.transformer.diffusion == DiffusionStyle.masked: + if config.transformer.diffusion == DiffusionStyle.masked: self._loss_name = LanguageModelLossNames.mlm_loss def _logits_cross_entropy_forward_backward( @@ -402,44 +402,67 @@ def _logits_cross_entropy_forward_backward( sequence_parallel=self._sequence_parallel and self._parallel_embeddings, ) - masked_indices = kwargs[LanguageModelKwargs.mask_indexes] - p_mask = kwargs[LanguageModelKwargs.mask_probabilities] - # index [0, 1, 2, 3, 4, 5] -> - # The labels are already left shifted x = [A, B, C, D, E, F] -> - # embd = [A, B, C, D, E] - # label = [B, C, D, E, F] - - # Question Pier: if 2 is the masked token, settling needs to settled 3; but 3 is already seen by the model, - # can it just learn to copy 3? i.e copy the next token to the masked? - # Yes. We need to drop those position from loss if the next token is not masked - # We can include curruption to further enhance, but it seems not to big looking at other CPT (diffuLlama) - - last_weight = 0 - B = logits.shape[0] - - loss_weight = torch.cat( - ( - # ar_weight * in_context[:, 1:] + # not implement yet - masked_indices[:, 1:] / p_mask[:, None], - # + un_weight * ((1-epsilon) * in_shuffled[:, 1:] + epsilon * in_clean[:, 1:]) / (1 - p_mask[:, None]) # not implement yet - (last_weight * torch.ones(B, device=logits.device)).unsqueeze(1), - # This may need some weighting in terms of masking. Let's do last_weight=0 TODO: Decide later - ), - dim=1, - ).to(logits.dtype) + if self.config.transformer.diffusion == DiffusionStyle.masked: + masked_indices = kwargs[LanguageModelKwargs.mask_indexes] + p_mask = kwargs[LanguageModelKwargs.mask_probabilities] + # index [0, 1, 2, 3, 4, 5] -> + # The labels are already left shifted x = [A, B, C, D, E, F] -> + # embd = [A, B, C, D, E] + # label = [B, C, D, E, F] + + # Question Pier: if 2 is the masked token, settling needs to settled 3; but 3 is already seen by the model, + # can it just learn to copy 3? i.e copy the next token to the masked? + # Yes. 
We need to drop those position from loss if the next token is not masked + # We can include curruption to further enhance, but it seems not to big looking at other CPT (diffuLlama) + + last_weight = 0 + B = logits.shape[0] + + loss_weight = torch.cat( + ( + # ar_weight * in_context[:, 1:] + # not implement yet + masked_indices[:, 1:] / p_mask[:, None], + # + un_weight * ((1-epsilon) * in_shuffled[:, 1:] + epsilon * in_clean[:, 1:]) / (1 - p_mask[:, None]) # not implement yet + (last_weight * torch.ones(B, device=logits.device)).unsqueeze(1), + # This may need some weighting in terms of masking. Let's do last_weight=0 TODO: Decide later + ), + dim=1, + ).to(logits.dtype) + + # print(f"Loss weight: {loss_weight}") - # print(f"Loss weight: {loss_weight}") + loss, grad = cross_entropy_forward_backward( + logits=logits.flatten(0, -2), + target=target, + loss_mask=None, + grad_output=grad_output, + group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, + implementation=self._cross_entropy_impl, + logits_scale_factor=self._logits_scale_factor, + loss_weight=loss_weight, + ) - loss, grad = cross_entropy_forward_backward( - logits=logits.flatten(0, -2), - target=target, - loss_mask=None, - grad_output=grad_output, - group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, - implementation=self._cross_entropy_impl, - logits_scale_factor=self._logits_scale_factor, - loss_weight=loss_weight, - ) + elif self.confing.transformer.diffusion == DiffusionStyle.ar_masked: + + loss_weights = kwargs[LanguageModelKwargs.loss_weights] + context_index = kwargs[LanguageModelKwargs.in_context] + masked_index = kwargs[LanguageModelKwargs.mask_indexes] + B = loss_weights.shape[0] + masked_index = torch.cat([masked_index[:, 1:], torch.zeros(B, 1, device=loss_weights.device)], dim=1) + context_index = torch.cat([context_index[:, 1:], torch.zeros(B, 1, device=loss_weights.device)], dim=1) + + loss, grad, per_token_loss_b4_weight = cross_entropy_forward_backward( + logits.flatten(0, -2), + target=target, + group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, + grad_output=grad_output, + implementation=self._cross_entropy_impl, + logits_scale_factor=self._logits_scale_factor, + loss_weight=loss_weights, + ) + + losses["loss_mask_tokens"].append((per_token_loss_b4_weight * masked_index).mean()) + losses["loss_in_context_tokens"].append((per_token_loss_b4_weight * context_index).mean()) # This happens with the loss_weight. # MDM https://github.com/ML-GSAI/SMDM/blob/583aa4716d17728dbb825aec6c24a121164d616a/pretrain/train_mdm.py#L274 diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 78c38b6d..7f6186c1 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -389,8 +389,10 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ softmax_scale=self._softmax_scale, ) input_ = input_.flatten(-2) + else: # TODO: Avoid the flattens. 
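The `loss_mask_tokens` and `loss_in_context_tokens` entries above are bookkeeping: they split the unweighted per-token loss into the part coming from masked targets and the part coming from clean context targets. A sketch of that split, assuming (as the patch does) that the unweighted per-token loss is available:

```python
import torch

B, L = 2, 5
per_token_loss = torch.rand(B, L)   # unweighted cross-entropy per target token
in_mask = torch.tensor([[0, 1, 0, 1, 0],
                        [1, 0, 0, 1, 0]], dtype=torch.bool)
in_context = torch.tensor([[1, 0, 0, 0, 0],
                           [0, 0, 1, 0, 0]], dtype=torch.bool)

# Average contribution of masked targets vs. in-context targets, for logging.
loss_mask_tokens = (per_token_loss * in_mask).mean()
loss_in_context_tokens = (per_token_loss * in_context).mean()
print(loss_mask_tokens.item(), loss_in_context_tokens.item())
```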
diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py
index 78c38b6d..7f6186c1 100644
--- a/fast_llm/layers/transformer/attention.py
+++ b/fast_llm/layers/transformer/attention.py
@@ -389,8 +389,10 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[
                 softmax_scale=self._softmax_scale,
             )
             input_ = input_.flatten(-2)
+
         else:
             # TODO: Avoid the flattens.
+
             input_ = self._attn_fused(
                 query.flatten(-2),
                 key.flatten(-2),
@@ -398,6 +400,30 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[
                 kwargs[TransformerKwargs.attention_mask],
                 kwargs[TransformerKwargs.attention_mask_value],
             )
+            # print(f"Fused: Attention: {input_.shape} {input_} ")
+
+            flash_input_ = _flash_attn_func(
+                query,
+                key,
+                value,
+                window_size=(-1, -1) if window_size is None else (window_size - 1, 0),
+                dropout_p=self._config.attention_dropout if self.training else 0.0,
+                causal=False,
+                softmax_scale=self._softmax_scale,
+            )
+            # print(f"1: Flash : Attention: {flash_input_.shape} {flash_input_} ")
+            flash_input_ = flash_input_.flatten(-2)
+            # print(f"2: Flash: Attention: {flash_input_.shape} {flash_input_} ")
+            diff = input_ - flash_input_
+            # print(f"Element-wise difference: {diff.shape} {diff}")
+            max_diff = diff.abs().max()
+            min_diff = diff.abs().min()
+            print(f"Min element-wise difference: {min_diff.item()}")
+            print(f"Max element-wise difference: {max_diff.item()}")
+            # if max_diff > 1e-3:
+            #     print("Warning: Max difference exceeds 1e-3")
+            #     import sys
+            #     sys.exit(1)

         if self._debug_transformer:
             self._debug_log(query, "query", self._QUERY_DIMS, kwargs)
diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py
index 0078eff5..a551b277 100644
--- a/fast_llm/models/gpt/model.py
+++ b/fast_llm/models/gpt/model.py
@@ -351,7 +351,8 @@ def preprocess(
             )
             # Setup bidirection attention for masked diffusion
             # It uses _flash_attn_func so no need to set attention_mask and attention_mask_value.
-            # kwargs[TransformerKwargs.causal] = False
+            kwargs[TransformerKwargs.causal] = False
+
             batch_size, seq_len = batch.token_ids.shape
             seq_len -= 1  # last token is dropped inputs
             attention_mask = torch.ones(
@@ -395,40 +396,48 @@ def preprocess(
             # seq_len -= 1 # last token is drop from the input
             # # Compute attention mask for diffusion
             C = batch.in_context_length.to(device=self._tensor_space.distributed.device)
-            # row_idx = torch.arange(seq_len, device=self._tensor_space.distributed.device).view(1, seq_len, 1)
-            # col_idx = torch.arange(seq_len, device=self._tensor_space.distributed.device).view(1, 1, seq_len)
-            # C_exp = C.view(batch_size, 1, 1)
+            row_idx = torch.arange(seq_len, device=self._tensor_space.distributed.device).view(
+                1, seq_len, 1
+            )
+            col_idx = torch.arange(seq_len, device=self._tensor_space.distributed.device).view(
+                1, 1, seq_len
+            )
+            C_exp = C.view(batch_size, 1, 1)

-            # causal_mask = col_idx <= row_idx
-            # row_idx < C_exp
-            # col_idx < C_exp
+            causal_mask = col_idx <= row_idx
+            row_idx < C_exp
+            col_idx < C_exp

-            # attn_mask = torch.zeros(
-            #     batch_size, seq_len, seq_len, dtype=torch.bool, device=self._tensor_space.distributed.device
-            # )
+            attn_mask = torch.zeros(
+                batch_size,
+                seq_len,
+                seq_len,
+                dtype=torch.bool,
+                device=self._tensor_space.distributed.device,
+            )

-            # for b in range(batch_size):
-            #     C_val = C[b].item()
+            for b in range(batch_size):
+                C_val = C[b].item()

-            #     if C_val > 0:
-            #         context_causal = causal_mask[0, :C_val, :C_val]
-            #         attn_mask[b, :C_val, :C_val] = context_causal
+                if C_val > 0:
+                    context_causal = causal_mask[0, :C_val, :C_val]
+                    attn_mask[b, :C_val, :C_val] = context_causal

-            #     if C_val > 0 and C_val < seq_len:
-            #         attn_mask[b, C_val:, :C_val] = True
+                if C_val > 0 and C_val < seq_len:
+                    attn_mask[b, C_val:, :C_val] = True

-            #     if C_val < seq_len:
-            #         attn_mask[b, C_val:, C_val:] = True
+                if C_val < seq_len:
+                    attn_mask[b, C_val:, C_val:] = True

             # Handle padding if needed
-            # if batch.sequence_lengths is not None:
-            #     padded = torch.zeros(
-            #         batch_size, seq_len, dtype=torch.bool, device=self._tensor_space.distributed.device
-            #     )
-            #     for b in range(batch_size):
-            #         padded[b, batch.sequence_lengths[b] :] = True
-            #     not_padded = ~padded[:, 1:]
-            #     attn_mask = attn_mask & not_padded.unsqueeze(1) & not_padded.unsqueeze(2)
+            if batch.sequence_lengths is not None:
+                padded = torch.zeros(
+                    batch_size, seq_len, dtype=torch.bool, device=self._tensor_space.distributed.device
+                )
+                for b in range(batch_size):
+                    padded[b, batch.sequence_lengths[b] :] = True
+                not_padded = ~padded[:, 1:]
+                attn_mask = attn_mask & not_padded.unsqueeze(1) & not_padded.unsqueeze(2)

             # Reshape to match expected attention mask format
             attention_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # Add additional dimension
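The per-sample loop in the preprocess hunk above builds a block attention mask: causal attention inside the first C in-context tokens, the remaining positions free to look back at that prefix, and full bidirectional attention among the remaining positions themselves. A vectorized sketch of the same mask (names and shapes are illustrative, not the repo's API):

    import torch

    def in_context_attention_mask(C: torch.Tensor, seq_len: int) -> torch.Tensor:
        # C: (B,) number of in-context prefix tokens per sample.
        # Returns a bool mask of shape (B, seq_len, seq_len); True means "may attend".
        B = C.shape[0]
        row = torch.arange(seq_len, device=C.device).view(1, seq_len, 1)
        col = torch.arange(seq_len, device=C.device).view(1, 1, seq_len)
        c = C.view(B, 1, 1)
        prefix_causal = (row < c) & (col <= row)        # attn_mask[b, :C, :C] = lower-triangular
        suffix_sees_prefix = (row >= c) & (col < c)     # attn_mask[b, C:, :C] = True
        suffix_bidirectional = (row >= c) & (col >= c)  # attn_mask[b, C:, C:] = True
        return prefix_causal | suffix_sees_prefix | suffix_bidirectional

The loop and the vectorized form should produce the same boolean mask; the vectorized version avoids the per-sample item() calls and stays on the device.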
From 068138f08eae64c15cf435961190b61adeec92a0 Mon Sep 17 00:00:00 2001
From: Luke Kumar
Date: Tue, 8 Jul 2025 00:53:41 +0000
Subject: [PATCH 33/34] minor

---
 fast_llm/layers/language_model/head.py | 121 +++++++++++++------------
 1 file changed, 61 insertions(+), 60 deletions(-)

diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py
index b8285b8c..fcd235cd 100644
--- a/fast_llm/layers/language_model/head.py
+++ b/fast_llm/layers/language_model/head.py
@@ -402,67 +402,68 @@ def _logits_cross_entropy_forward_backward(
             sequence_parallel=self._sequence_parallel and self._parallel_embeddings,
         )

-        if self.config.transformer.diffusion == DiffusionStyle.masked:
-            masked_indices = kwargs[LanguageModelKwargs.mask_indexes]
-            p_mask = kwargs[LanguageModelKwargs.mask_probabilities]
-            # index [0, 1, 2, 3, 4, 5] ->
-            # The labels are already left shifted x = [A, B, C, D, E, F] ->
-            # embd  = [A, B, C, D, E]
-            # label = [B, C, D, E, F]
-
-            # Question Pier: if 2 is the masked token, position 2 needs to predict token 3; but token 3 is already seen by the model,
-            # can it just learn to copy 3? i.e. copy the next token to the masked position?
-            # Yes. We need to drop those positions from the loss if the next token is not masked.
-            # We can include corruption to further improve this, but the gain seems not too big looking at other CPT (diffuLlama).
-
-            last_weight = 0
-            B = logits.shape[0]
-
-            loss_weight = torch.cat(
-                (
-                    # ar_weight * in_context[:, 1:] +  # not implemented yet
-                    masked_indices[:, 1:] / p_mask[:, None],
-                    # + un_weight * ((1-epsilon) * in_shuffled[:, 1:] + epsilon * in_clean[:, 1:]) / (1 - p_mask[:, None])  # not implemented yet
-                    (last_weight * torch.ones(B, device=logits.device)).unsqueeze(1),
-                    # This may need some weighting in terms of masking. Let's do last_weight=0 TODO: Decide later
-                ),
-                dim=1,
-            ).to(logits.dtype)
-
-            # print(f"Loss weight: {loss_weight}")
-
-            loss, grad = cross_entropy_forward_backward(
-                logits=logits.flatten(0, -2),
-                target=target,
-                loss_mask=None,
-                grad_output=grad_output,
-                group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None,
-                implementation=self._cross_entropy_impl,
-                logits_scale_factor=self._logits_scale_factor,
-                loss_weight=loss_weight,
-            )
-
-        elif self.config.transformer.diffusion == DiffusionStyle.ar_masked:
-
-            loss_weights = kwargs[LanguageModelKwargs.loss_weights]
-            context_index = kwargs[LanguageModelKwargs.in_context]
-            masked_index = kwargs[LanguageModelKwargs.mask_indexes]
-            B = loss_weights.shape[0]
-            masked_index = torch.cat([masked_index[:, 1:], torch.zeros(B, 1, device=loss_weights.device)], dim=1)
-            context_index = torch.cat([context_index[:, 1:], torch.zeros(B, 1, device=loss_weights.device)], dim=1)
-
-            loss, grad, per_token_loss_b4_weight = cross_entropy_forward_backward(
-                logits.flatten(0, -2),
-                target=target,
-                group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None,
-                grad_output=grad_output,
-                implementation=self._cross_entropy_impl,
-                logits_scale_factor=self._logits_scale_factor,
-                loss_weight=loss_weights,
-            )
+        if self.config.transformer.diffusion is not None:
+            if self.config.transformer.diffusion == DiffusionStyle.masked:
+                masked_indices = kwargs[LanguageModelKwargs.mask_indexes]
+                p_mask = kwargs[LanguageModelKwargs.mask_probabilities]
+                # index [0, 1, 2, 3, 4, 5] ->
+                # The labels are already left shifted x = [A, B, C, D, E, F] ->
+                # embd  = [A, B, C, D, E]
+                # label = [B, C, D, E, F]
+
+                # Question Pier: if 2 is the masked token, position 2 needs to predict token 3; but token 3 is already seen by the model,
+                # can it just learn to copy 3? i.e. copy the next token to the masked position?
+                # Yes. We need to drop those positions from the loss if the next token is not masked.
+                # We can include corruption to further improve this, but the gain seems not too big looking at other CPT (diffuLlama).
+
+                last_weight = 0
+                B = logits.shape[0]
+
+                loss_weight = torch.cat(
+                    (
+                        # ar_weight * in_context[:, 1:] +  # not implemented yet
+                        masked_indices[:, 1:] / p_mask[:, None],
+                        # + un_weight * ((1-epsilon) * in_shuffled[:, 1:] + epsilon * in_clean[:, 1:]) / (1 - p_mask[:, None])  # not implemented yet
+                        (last_weight * torch.ones(B, device=logits.device)).unsqueeze(1),
+                        # This may need some weighting in terms of masking. Let's do last_weight=0 TODO: Decide later
+                    ),
+                    dim=1,
+                ).to(logits.dtype)
+
+                # print(f"Loss weight: {loss_weight}")
+
+                loss, grad = cross_entropy_forward_backward(
+                    logits=logits.flatten(0, -2),
+                    target=target,
+                    loss_mask=None,
+                    grad_output=grad_output,
+                    group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None,
+                    implementation=self._cross_entropy_impl,
+                    logits_scale_factor=self._logits_scale_factor,
+                    loss_weight=loss_weight,
+                )
-            losses["loss_mask_tokens"].append((per_token_loss_b4_weight * masked_index).mean())
-            losses["loss_in_context_tokens"].append((per_token_loss_b4_weight * context_index).mean())
+            elif self.config.transformer.diffusion == DiffusionStyle.ar_masked:
+
+                loss_weights = kwargs[LanguageModelKwargs.loss_weights]
+                context_index = kwargs[LanguageModelKwargs.in_context]
+                masked_index = kwargs[LanguageModelKwargs.mask_indexes]
+                B = loss_weights.shape[0]
+                masked_index = torch.cat([masked_index[:, 1:], torch.zeros(B, 1, device=loss_weights.device)], dim=1)
+                context_index = torch.cat([context_index[:, 1:], torch.zeros(B, 1, device=loss_weights.device)], dim=1)
+
+                loss, grad, per_token_loss_b4_weight = cross_entropy_forward_backward(
+                    logits.flatten(0, -2),
+                    target=target,
+                    group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None,
+                    grad_output=grad_output,
+                    implementation=self._cross_entropy_impl,
+                    logits_scale_factor=self._logits_scale_factor,
+                    loss_weight=loss_weights,
+                )
+                # Add these before weighting to display them separately
+                losses["loss_mask_tokens"].append((per_token_loss_b4_weight * masked_index).mean())
+                losses["loss_in_context_tokens"].append((per_token_loss_b4_weight * context_index).mean())

         # This happens with the loss_weight.
         # MDM https://github.com/ML-GSAI/SMDM/blob/583aa4716d17728dbb825aec6c24a121164d616a/pretrain/train_mdm.py#L274
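The two loss entries added above log the unweighted per-token loss separately for masked and in-context positions. A small sketch of that bookkeeping under the same assumptions (hypothetical names; a (B, L) per-token loss and 0/1 indicator masks aligned with the left-shifted labels):

    import torch

    def log_split_losses(per_token_loss, masked_index, context_index, losses):
        # per_token_loss: (B, L) unweighted cross-entropy per position.
        # masked_index / context_index: (B, L) indicators for masked and in-context tokens.
        B = per_token_loss.shape[0]
        pad = torch.zeros(B, 1, device=per_token_loss.device, dtype=per_token_loss.dtype)
        # Shift by one so the indicators line up with the left-shifted labels and zero the last position.
        masked_index = torch.cat([masked_index[:, 1:].to(pad.dtype), pad], dim=1)
        context_index = torch.cat([context_index[:, 1:].to(pad.dtype), pad], dim=1)
        losses["loss_mask_tokens"].append((per_token_loss * masked_index).mean())
        losses["loss_in_context_tokens"].append((per_token_loss * context_index).mean())

Note that .mean() averages over all B x L positions, so these diagnostics scale with the masking rate; dividing by the indicator sum instead would give a per-selected-token average, if that is the intent.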
From a573cd76e50f02b7c729a5f21915a12352945bd5 Mon Sep 17 00:00:00 2001
From: Luke Kumar
Date: Tue, 8 Jul 2025 13:40:11 +0000
Subject: [PATCH 34/34] attn verification checks

---
 fast_llm/layers/transformer/attention.py | 15 ++++-----------
 fast_llm/models/gpt/model.py             | 15 ++++++++++++---
 2 files changed, 16 insertions(+), 14 deletions(-)

diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py
index 7f6186c1..3d78e1aa 100644
--- a/fast_llm/layers/transformer/attention.py
+++ b/fast_llm/layers/transformer/attention.py
@@ -182,11 +182,7 @@ def _attn_fused(
         ).view(b, self._local_head_groups, sq, self._local_heads_per_group, sk)

         attn_weights = attn_weights.to(torch.float32) * self._layer_index
-
-        attn_weights = attn_weights.transpose(2, 3)
         attn_weights = torch.where(mask, attn_weights, mask_value)
-        attn_weights = attn_weights.transpose(2, 3)
-
         attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype)

         with set_generator(self._tensor_space.distributed.tp_generator):
@@ -417,13 +413,10 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[
             diff = input_ - flash_input_
             # print(f"Element-wise difference: {diff.shape} {diff}")
             max_diff = diff.abs().max()
-            min_diff = diff.abs().min()
-            print(f"Min element-wise difference: {min_diff.item()}")
-            print(f"Max element-wise difference: {max_diff.item()}")
-            # if max_diff > 1e-3:
-            #     print("Warning: Max difference exceeds 1e-3")
-            #     import sys
-            #     sys.exit(1)
+
+            if max_diff > 1e-3:
+                print("Warning: Max difference exceeds 1e-3")
+                print(f"Max element-wise difference: {max_diff.item()}")

         if self._debug_transformer:
             self._debug_log(query, "query", self._QUERY_DIMS, kwargs)
diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py
index a551b277..4049bbe1 100644
--- a/fast_llm/models/gpt/model.py
+++ b/fast_llm/models/gpt/model.py
@@ -57,7 +57,7 @@ def __init__(
             # TODO: Find a better solution.
             self._preprocessors.append(self._config.transformer.rotary.build(self._tensor_space))

-        if not self._config.transformer.diffusion:
+        if self._config.transformer.diffusion is None:
             if self._use_flash_attention:
                 self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space))
             else:
@@ -355,12 +355,21 @@ def preprocess(

             batch_size, seq_len = batch.token_ids.shape
             seq_len -= 1  # last token is dropped inputs
+            # attention_mask = torch.ones(
+            #     (batch_size, 1, seq_len, seq_len),
+            #     dtype=torch.bool,
+            #     device=self._tensor_space.distributed.device,
+            # )
+            # kwargs[TransformerKwargs.attention_mask] = attention_mask.unsqueeze(1).unsqueeze(1)
             attention_mask = torch.ones(
-                (batch_size, 1, seq_len, seq_len),
+                (seq_len, seq_len),
                 dtype=torch.bool,
                 device=self._tensor_space.distributed.device,
             )
-            kwargs[TransformerKwargs.attention_mask] = attention_mask.unsqueeze(1).unsqueeze(1)
+            kwargs[TransformerKwargs.attention_mask] = attention_mask[
+                None, None, 0:seq_len, None, :seq_len
+            ]
+            print(f"attention_mask: {kwargs[TransformerKwargs.attention_mask]}")
             # # kwargs[TransformerKwargs.attention_mask_value] = torch.tensor(
             # #     -10000.0, device=self._tensor_space.distributed.device
             # # )
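The check introduced in this last commit compares the fused attention path against flash attention at runtime and warns when they drift apart. The same idea as a standalone helper (illustrative only; the names and the 1e-3 threshold are assumptions carried over from the hunk above):

    import torch

    def check_attention_parity(fused_out: torch.Tensor, flash_out: torch.Tensor, tol: float = 1e-3) -> bool:
        # Compare two attention implementations on the same inputs and flag drift.
        max_diff = (fused_out.float() - flash_out.float()).abs().max().item()
        if max_diff > tol:
            print(f"Warning: max element-wise difference {max_diff:.3e} exceeds {tol}")
            return False
        return True

With reduced-precision inputs some difference is expected, so a fixed threshold like this is a heuristic sanity check rather than a strict equivalence test; under a test harness, torch.testing.assert_close with dtype-appropriate tolerances is a common alternative.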