diff --git a/fast_llm/config.py b/fast_llm/config.py index 0004501b..4740bd03 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -1099,3 +1099,13 @@ def pop_nested_dict_value[ return d.pop(keys[-1]) else: return d.pop(keys) + + +class DiffusionStyle(str, enum.Enum): + """ + Type of diffusion masking to use. + """ + + masked = "masked" + ar_masked = "autoregressive_masked" + none = None diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 6724afb5..3d5f990b 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -9,6 +9,7 @@ import torch import torch.utils.data +from fast_llm.config import DiffusionStyle from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.abstract import Data from fast_llm.data.data.gpt.config import GPTDataConfig @@ -34,12 +35,222 @@ class GPTBatch: sequence_lengths: list[torch.Tensor] | None = None chosen_spans: list[torch.Tensor] | None = None rejected_spans: list[torch.Tensor] | None = None + mask_indexes: torch.Tensor | None = None + mask_probabilities: torch.Tensor | None = None + masked_token_ids: torch.Tensor | None = None + loss_weights: torch.Tensor | None = None + in_context_length: torch.Tensor | None = None + in_context: torch.Tensor | None = None + + +def do_mask(x, mask, mask_token_id): + x = x.clone() + x[mask] = mask_token_id + return x + + +def do_uniform(x, is_uniform, vocab_size): + # WARNING! "Shuffle" was really meant to mean "uniformly sample among all non-mask tokens" + x = x.clone() + uniform = torch.randint(0, vocab_size, x.size()) + x[is_uniform] = uniform[is_uniform] + return x + + +def prepare_batch( + data_ids, + positions, + padded, + mask_token_id, + vocab_size, + context_length, + p_mask, + *, + p_uniform=0.0, + ar_factor=1.0, + un_factor=1.0, + last_factor=0.0, + in_mask=None, + in_uniform=None, +): + + B, L = positions.size() + context_length = context_length.unsqueeze(1).expand(B, L) + p_mask = p_mask.unsqueeze(1) + + # Reminder: a context_length of zero still has one in_context token (à la ) + in_context = positions <= context_length + if in_mask is None: + in_mask = (~in_context) & (torch.rand(B, L) < p_mask) + + if in_uniform is None: + in_uniform = (~in_context) & (~in_mask) & (torch.rand(B, L) < p_uniform) + in_clean = (~in_context) & (~in_mask) & (~in_uniform) + + loss_weights = (~padded)[:, 1:] * torch.cat( + [ + ar_factor * in_context[:, 1:] + + in_mask[:, 1:] / p_mask + + un_factor * ((1 - p_uniform) * in_uniform[:, 1:] + p_uniform * in_clean[:, 1:]) / (1 - p_mask), + last_factor * torch.ones(B, 1), + ], + dim=1, + ) + + input_ids = do_uniform(do_mask(data_ids[:, :-1], in_mask, mask_token_id), in_uniform, vocab_size) + + # print( + # f"{'Name':<20} {'Shape/Value':<30}\n" + # f"{'-'*50}\n" + # f"{'input_ids':<20} {str(input_ids.shape):<30}\n" + # f"{'in_context':<20} {str(in_context.shape):<30}\n" + # f"{'in_mask':<20} {str(in_mask.shape):<30}\n" + # f"{'in_uniform':<20} {str(in_uniform.shape):<30}\n" + # f"{'in_clean':<20} {str(in_clean.shape):<30}\n" + # f"{'loss_weights':<20} {str(loss_weights.shape):<30}\n" + # f"{'in_context_length':<20} {str(context_length):<30}\n" + # ) + + return { + "in_context": in_context, # Only for tokens to be predicted + "in_mask": in_mask, + "in_uniform": in_uniform, + "in_clean": in_clean, + "input_ids": input_ids, + # "target_ids": data_ids, + "loss_weights": loss_weights, + # "in_context_length": context_length, + } def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) 
-> GPTBatch: + stacked_ids = np.stack([sample.token_ids for sample in batch]) stacked_spans = None sequence_lengths = None + mask_indexes = None + mask_probabilities = None + masked_token_ids = None + + loss_weights = None + in_context_length = None + in_context = None + + token_ids = torch.from_numpy(stacked_ids) + + if sampling_parameters.diffusion.style == DiffusionStyle.masked: + diffusion_config = sampling_parameters.diffusion + + batch_size, seq_len = token_ids.shape + positions = torch.arange(seq_len - 1).unsqueeze(0).expand(batch_size, seq_len - 1) + padded = torch.zeros_like(token_ids, dtype=torch.bool) + + # Generate a random tensor of batch size to seed masking probabilities + t = torch.rand((batch_size,)) + + # Compute the mask probabilities for every sequence in the batch + p_mask = (1 - (2 * diffusion_config.epsilon)) * t + diffusion_config.epsilon + + # Do we need to clamp at max_mask_prob? + # p_mask = torch.min(p_mask, torch.tensor(diffusion_config.max_mask_prob)) + + # Input has an additional token for shifting: [0, 1, 2, 3, 4] -> [1, 2, 3, 4] + + # index [0, 1, 2, 3, 4, 5] -> + # The labels are already left shifted x = [A, B, C, D, E, F] -> + # embd = [A, B, C, D, E] + # label = [B, C, D, E, F] + # Last input token is dropped from the processing + + # Generate random values for all tokens in the batch and only mask the positions + # where the value is smaller than the mask probability + # mask_indexes = torch.rand((batch_size, seq_len)) < p_mask[:, None] + + # Needs further clarification of this padding - 1% of data to have partial sequences and padding + # if diffusion_config.pad_prob > 0: + # pad_mask = torch.rand((batch_size,), device=device) < diffusion_config.pad_prob + # if pad_mask.any(): + # mask_indexes[pad_mask] = True + + # Replace masked tokens with the mask token ID to create input for the model. + # masked_token_ids = torch.where(mask_indexes, mask_token_id, token_ids) + # masked_token_ids = masked_token_ids[:, :-1] # Remove the last token, which is not used for prediction. + + # mask_indexes = mask_indexes[:, 1:] # Shift left so previous token to mask is the index for loss.
+ # mask_probabilities = p_mask + + batch_data = prepare_batch( + data_ids=token_ids, + positions=positions, + padded=padded, + mask_token_id=diffusion_config.mask_token_id, + vocab_size=sampling_parameters.vocab_size, + context_length=-torch.ones(batch_size, dtype=torch.int), # No context length for masked diffusion + p_mask=p_mask, + p_uniform=0.0, # no uniform shuffling of tokens + ar_factor=0.0, + un_factor=0.0, + last_factor=0.0, + ) + + # token_ids = batch_data["input_ids"] + masked_token_ids = batch_data["input_ids"] + + mask_indexes = batch_data["in_mask"] + # mask_probabilities = torch.full_like(mask_indexes, diffusion_config.max_mask_prob, dtype=token_ids.dtype) + loss_weights = batch_data["loss_weights"] + # in_context_length = C + in_context = batch_data["in_context"] + + elif sampling_parameters.diffusion.style == DiffusionStyle.ar_masked: + diffusion_config = sampling_parameters.diffusion + batch_size, seq_len = token_ids.shape + data_ids = token_ids + padded = torch.zeros_like(data_ids, dtype=torch.bool) + positions = torch.arange(seq_len - 1).unsqueeze(0).expand(batch_size, seq_len - 1) + + # TODO: + # 90% of the batch: C = random [0, seq_len // 4], 10%: C = random in [0, seq_len-2) + prob = torch.rand(1) + C = torch.where( + prob > diffusion_config.context_sampler, + torch.randint(0, seq_len // 4, (batch_size,), dtype=torch.long), + torch.randint(0, seq_len - 2, (batch_size,), dtype=torch.long), + ) + # C = torch.randint(0, (seq_len - 2), (batch_size,), dtype=torch.long) + # C = -torch.ones(batch_size, dtype=torch.int) + # Generate a random tensor of batch size to seed masking probabilities + t = torch.rand((batch_size,)) + # Compute the mask probabilities for every sequence in the batch leaving extrams 0 & 1 + p_mask = (1 - (2 * diffusion_config.epsilon)) * t + diffusion_config.epsilon + + batch_data = prepare_batch( + data_ids=data_ids, + positions=positions, + padded=padded, + mask_token_id=diffusion_config.mask_token_id, + vocab_size=sampling_parameters.vocab_size, + context_length=C, + p_mask=p_mask, + p_uniform=0.0, # no uniform shuffling of tokens + ar_factor=diffusion_config.ar_factor, + un_factor=1.0, + last_factor=0.0, + ) + + # token_ids = batch_data["input_ids"] + masked_token_ids = batch_data["input_ids"] + + mask_indexes = batch_data["in_mask"] + # mask_probabilities = torch.full_like(mask_indexes, diffusion_config.max_mask_prob, dtype=token_ids.dtype) + loss_weights = batch_data["loss_weights"] + in_context_length = C + in_context = batch_data["in_context"] + + if sampling_parameters.use_loss_masking_spans: + stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] + stacked_chosen_spans = None stacked_rejected_spans = None if sampling_parameters.use_loss_masking_spans: @@ -49,12 +260,19 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling stacked_rejected_spans = [torch.from_numpy(sample.rejected_span) for sample in batch] if not sampling_parameters.cross_document_attention: sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch] + return GPTBatch( - token_ids=torch.from_numpy(stacked_ids), + token_ids=token_ids, loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths, chosen_spans=stacked_chosen_spans, rejected_spans=stacked_rejected_spans, + mask_indexes=mask_indexes, + mask_probabilities=mask_probabilities, + masked_token_ids=masked_token_ids, + loss_weights=loss_weights, + in_context_length=in_context_length, + in_context=in_context, ) diff --git 
a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index ef2efedc..6012b78b 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -8,7 +8,16 @@ import yaml -from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none +from fast_llm.config import ( + Config, + DiffusionStyle, + Field, + FieldHint, + FieldUpdate, + check_field, + config_class, + skip_valid_if_none, +) from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.dataset.config import ( BlendedDatasetConfig, @@ -44,6 +53,43 @@ class ShufflingType(str, enum.Enum): legacy = "legacy" +@config_class(registry=True) +class DiffusionMaskingConfig(Config): + """Configuration for diffusion-based masking during data preparation.""" + + style: DiffusionStyle = Field( + default=DiffusionStyle.none, desc="Which style of diffusion masking to use during training.", hint=FieldHint.feature + ) + + epsilon: float = Field( + default=1e-3, desc="Minimum masking probability", hint=FieldHint.performance, valid=check_field(Assert.gt, 0) + ) + + max_mask_prob: float = Field( + default=0.15, desc="Maximum masking probability", hint=FieldHint.performance, valid=check_field(Assert.gt, 0) + ) + + pad_prob: float = Field( + default=0.01, + desc="Probability of adding padding to a sample (default targets 1% of samples).", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + + mask_token_id: int = Field(default=103, desc="Token ID to use for masking", hint=FieldHint.optional) + ar_factor: float = Field( + default=1.0, + desc="Factor for the AR weighting on the overall loss.", + hint=FieldHint.optional, + ) + context_sampler: float = Field( + default=1.0, desc="Threshold for sampling the context length C: draws above it restrict C to the first quarter of the sequence length, draws below it use the full range.", hint=FieldHint.optional + ) + + def _validate(self) -> None: + super()._validate() + + @config_class() class GPTSamplingConfig(SamplingConfig): """ @@ -62,6 +108,10 @@ class GPTSamplingConfig(SamplingConfig): desc="Shuffling strategy.", hint=FieldHint.feature, ) + diffusion: DiffusionMaskingConfig = Field( + desc="Configuration for diffusion-based masking during data preparation.", + hint=FieldHint.feature, + ) @dataclasses.dataclass(kw_only=True) @@ -79,6 +129,7 @@ class GPTSamplingParameters(SamplingParameters): # How many extra tokens to add to the sequence length. # This is used to provide labels even for the last tokens in the sequence. extra_tokens: int = 1 + diffusion: DiffusionMaskingConfig @dataclasses.dataclass(kw_only=True) diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index df603a91..51252f03 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -137,9 +137,8 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: # The name (dict key) is used to insert the weight in the kwargs of the forward pass.
return {} - @property @abc.abstractmethod - def loss_defs(self) -> list[LossDef]: + def get_loss_defs(self) -> list[LossDef]: pass def add_reference_model(self, name: str, inference_runner: "InferenceRunner") -> None: diff --git a/fast_llm/engine/evaluation/evaluator.py b/fast_llm/engine/evaluation/evaluator.py index 78aad230..96e71e94 100644 --- a/fast_llm/engine/evaluation/evaluator.py +++ b/fast_llm/engine/evaluation/evaluator.py @@ -119,7 +119,7 @@ def setup( phase=PhaseType.validation, ) - self._loss_defs = self._multi_stage.base_model.loss_defs + self._loss_defs = self._multi_stage.base_model.get_loss_defs() self._evaluation_iterator = None self._is_setup = True diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 8eca4559..f2b302de 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -93,7 +93,7 @@ def __init__( self._stages: list[Stage] = self._multi_stage.stages self._tied_parameters = self._multi_stage.tied_parameters self._num_stages = len(self._stages) - self._loss_defs = {loss_def.name: loss_def for loss_def in self._multi_stage.base_model.loss_defs} + self._loss_defs = {loss_def.name: loss_def for loss_def in self._multi_stage.base_model.get_loss_defs()} def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> None: assert not self._is_setup diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 766398d0..4a6ae161 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -152,7 +152,7 @@ def __init__(self, config: TrainerConfig): multi_stage=self._multi_stage, distributed_config=self._config.model.distributed, ) - self._loss_defs = self._multi_stage.base_model.loss_defs + self._loss_defs = self._multi_stage.base_model.get_loss_defs() if not self._is_evaluation_only: steps_per_split = { diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 513510ec..3b3d72e6 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -15,6 +15,7 @@ def _torch_cross_entropy_forward_backward( grad_output: float | None, logits_scale_factor: float, target_format: TargetFormat, + loss_weight: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A wrapper for the pytorch implementation of cross-entropy. @@ -22,6 +23,8 @@ def _torch_cross_entropy_forward_backward( and separate forward and backward kernels lead to poor performance. TODO: loss masking only works for with labels format and if the masking index is set to -100. """ + assert loss_weight is None, "Loss weight not supported in torch cross-entropy implementation." + # Torch compile doesn't understand this. 
with torch.set_grad_enabled(grad_output is not None): logits_ = logits.float().detach().requires_grad_(grad_output is not None) @@ -82,6 +85,7 @@ def _fused_cross_entropy_forward_backward( grad_output: float | None, logits_scale_factor: float, target_format: TargetFormat, + loss_weight: torch.Tensor | None, group: ProcessGroup | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ @@ -144,14 +148,25 @@ def _fused_cross_entropy_forward_backward( predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True) per_sample_loss = sum_exp_logits.log() - predicted_logits - if loss_mask is not None: - per_sample_loss = per_sample_loss * loss_mask + # print(f"Per sample loss {per_sample_loss} {per_sample_loss.shape}") - loss = per_sample_loss.mean() - if target_format != TargetFormat.labels and group is not None: - all_reduce(loss, op=ReduceOp.MEAN, group=group) + if loss_weight is None: + if loss_mask is not None: + per_sample_loss = per_sample_loss * loss_mask - return loss, grad + loss = per_sample_loss.mean() + if target_format != TargetFormat.labels and group is not None: + all_reduce(loss, op=ReduceOp.MEAN, group=group) + return loss, grad + else: + # Weight every token loss by the loss weight. Before averaging. + per_sample_loss = per_sample_loss * loss_weight.view(-1, 1) + loss_weight_expanded = loss_weight.reshape(-1, 1) + grad = grad * loss_weight_expanded if grad is not None else None + # print(f"Loss {per_sample_loss} {per_sample_loss.shape}") + denom = torch.clamp((loss_weight != 0).sum(), min=1) + # print(f"avg all: {per_sample_loss.mean()}") + return per_sample_loss.sum() / denom, grad _CROSS_ENTROPY_IMPLEMENTATIONS = { @@ -170,6 +185,7 @@ def cross_entropy_forward_backward( implementation: CrossEntropyImpl = CrossEntropyImpl.fused, logits_scale_factor: float = 1.0, target_format: TargetFormat = TargetFormat.labels, + loss_weight: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Select the appropriate implementation of cross-entropy. @@ -177,6 +193,7 @@ def cross_entropy_forward_backward( It doesn't have a tensor-parallel implementation, but can be computed in a sequence-tensor-parallel way, which is faster and has a relatively small memory overhead. """ + if target_format == TargetFormat.labels: Assert.eq(target.shape, logits.shape[:-1]) Assert.eq(target.dtype, torch.int64) @@ -193,5 +210,5 @@ def cross_entropy_forward_backward( ) else: return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation]( - logits, target, loss_mask, grad_output, logits_scale_factor, target_format + logits, target, loss_mask, grad_output, logits_scale_factor, target_format, loss_weight=loss_weight ) diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 8cb59c85..c257edbb 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -125,6 +125,7 @@ def triton_cross_entropy_forward_backward( grad_output: float | None, logits_scale_factor: float, target_format: TargetFormat, + loss_weight: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ A fast triton implementation of cross-entropy, which combines the casting and forward and backward passes, @@ -133,6 +134,8 @@ def triton_cross_entropy_forward_backward( TODO: Better handling of `grad_output = None` """ assert TritonConfig.TRITON_ENABLED + assert loss_weight is None, "Loss weight not supported in triton cross-entropy implementation." + # TODO: Improve assumptions. 
assert logits.is_contiguous() assert target.is_contiguous() diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index c4776abe..e0d22e54 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -22,6 +22,7 @@ class LanguageModelDimNames: class LanguageModelLossNames: language_model_loss = "language_model_loss" z_loss = "z_loss" + mlm_loss = "masked_language_model_loss" @staticmethod def multi_token_prediction_loss(index: int) -> str: @@ -38,7 +39,11 @@ class LanguageModelKwargs: chosen_spans = "chosen_spans" rejected_spans = "rejected_spans" loss_mask = "loss_mask" + mask_indexes = "mask_indexes" + mask_probabilities = "mask_probabilities" mask_inputs = "mask_inputs" + loss_weights = "loss_weights" + in_context = "in_context" @config_class() diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 88b0612b..4d42328b 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -4,7 +4,7 @@ from torch._C._distributed_c10d import ReduceOp # noqa from torch.distributed import all_reduce -from fast_llm.config import Configurable +from fast_llm.config import Configurable, DiffusionStyle from fast_llm.core.ops import split_op from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace @@ -363,3 +363,115 @@ def _logits_cross_entropy_forward_backward( # TODO: de-allocate earlier. del logits return loss, output_parallel_linear_backward(grad, context) if self.training else None + + +class MLMHead(LanguageModelHead): + """ + A masked language model head for diffusion-based training.` + """ + + def __init__( + self, + config: LanguageModelBaseConfig, + tensor_space: TensorSpace, + prediction_distance: int, + ): + super().__init__(config, tensor_space, prediction_distance) + if config.transformer.diffusion == DiffusionStyle.masked: + self._loss_name = LanguageModelLossNames.mlm_loss + + def _logits_cross_entropy_forward_backward( + self, + input_: torch.Tensor, + target: torch.Tensor | None, + loss_mask: torch.Tensor | None, + weight: torch.Tensor, + grad_output: float, + kwargs: dict, + losses: dict | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + + assert target is not None, "MLM head requires target labels" + assert loss_mask is None, "MLM head does not support loss mask" + + logits, context = output_parallel_linear_forward( + input_=input_, + weight=weight, + bias=None, + group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, + sequence_parallel=self._sequence_parallel and self._parallel_embeddings, + ) + + if self.config.transformer.diffusion is not None: + if self.config.transformer.diffusion == DiffusionStyle.masked: + # masked_indices = kwargs[LanguageModelKwargs.mask_indexes] + # p_mask = kwargs[LanguageModelKwargs.mask_probabilities] + # # index [0, 1, 2, 3, 4, 5] -> + # # The labels are already left shifted x = [A, B, C, D, E, F] -> + # # embd = [A, B, C, D, E] + # # label = [B, C, D, E, F] + + # # Question Pier: if 2 is the masked token, settling needs to settled 3; but 3 is already seen by the model, + # # can it just learn to copy 3? i.e copy the next token to the masked? + # # Yes. 
We need to drop those positions from the loss if the next token is not masked + # # We can include corruption to further enhance, but it seems not too big looking at other CPT (diffuLlama) + + # last_weight = 0 + # B = logits.shape[0] + + # loss_weight = torch.cat( + # ( + # # ar_weight * in_context[:, 1:] + # not implemented yet + # masked_indices[:, :-1] / p_mask[:, None], + # # + un_weight * ((1-epsilon) * in_shuffled[:, 1:] + epsilon * in_clean[:, 1:]) / (1 - p_mask[:, None]) # not implemented yet + # (last_weight * torch.ones(B, device=logits.device)).unsqueeze(1), + # # This may need some weighting in terms of masking. Let's do last_weight=0 TODO: Decide later + # ), + # dim=1, + # ).to(logits.dtype) + + loss_weights = kwargs[LanguageModelKwargs.loss_weights] + # print(f"Loss weight: {loss_weights} {loss_weights.shape} ") + + loss, grad = cross_entropy_forward_backward( + logits=logits.flatten(0, -2), + target=target, + loss_mask=None, + grad_output=grad_output, + group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, + implementation=self._cross_entropy_impl, + logits_scale_factor=self._logits_scale_factor, + loss_weight=loss_weights, + ) + + elif self.config.transformer.diffusion == DiffusionStyle.ar_masked: + + loss_weights = kwargs[LanguageModelKwargs.loss_weights] + context_index = kwargs[LanguageModelKwargs.in_context] + masked_index = kwargs[LanguageModelKwargs.mask_indexes] + B = loss_weights.shape[0] + masked_index = torch.cat([masked_index[:, 1:], torch.zeros(B, 1, device=loss_weights.device)], dim=1) + context_index = torch.cat([context_index[:, 1:], torch.zeros(B, 1, device=loss_weights.device)], dim=1) + + loss, grad, per_token_loss_b4_weight = cross_entropy_forward_backward( + logits.flatten(0, -2), + target=target, + group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, + grad_output=grad_output, + implementation=self._cross_entropy_impl, + logits_scale_factor=self._logits_scale_factor, + loss_weight=loss_weights, + ) + # Add these before weighting to display them separately + losses["loss_mask_tokens"].append((per_token_loss_b4_weight * masked_index).mean()) + losses["loss_in_context_tokens"].append((per_token_loss_b4_weight * context_index).mean()) + + # The weighting itself is applied via loss_weight inside the cross-entropy call.
+ # MDM https://github.com/ML-GSAI/SMDM/blob/583aa4716d17728dbb825aec6c24a121164d616a/pretrain/train_mdm.py#L274 + + # Compute per token loss by avg across all tokens in the batch (tokens we ignore are assumed to have a 0 loss still counted towards the average) + # done inside the cross-entropy function + # MDM https://github.com/ML-GSAI/SMDM/blob/583aa4716d17728dbb825aec6c24a121164d616a/pretrain/train_mdm.py#L275 + + del logits + return loss, output_parallel_linear_backward(grad, context) if self.training else None diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 3351c990..c6f43a86 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -371,7 +371,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ max_seqlen_k=kwargs.get(TransformerKwargs.max_seqlen_k), dropout_p=self._config.attention_dropout if self.training else 0.0, window_size=(-1, -1) if window_size is None else (window_size - 1, 0), - causal=True, + causal=kwargs.get(TransformerKwargs.causal, True), softmax_scale=self._softmax_scale, ).view(*out_dims) else: @@ -381,10 +381,11 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ value, window_size=(-1, -1) if window_size is None else (window_size - 1, 0), dropout_p=self._config.attention_dropout if self.training else 0.0, - causal=True, + causal=kwargs.get(TransformerKwargs.causal, True), softmax_scale=self._softmax_scale, ) input_ = input_.flatten(-2) + else: # TODO: Avoid the flattens. input_ = self._attn_fused( diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index f6eaf589..7e0155f6 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -6,7 +6,7 @@ import typing import warnings -from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import DiffusionStyle, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames @@ -82,6 +82,7 @@ class TransformerKwargs: sequence_length = "sequence_length" # TODO: Move grad_output = "grad_output" + causal = "causal" class TransformerLossNames: @@ -485,6 +486,11 @@ class TransformerConfig(LLMBlockConfig): " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.", hint=FieldHint.expert, ) + diffusion: DiffusionStyle = Field( + default=DiffusionStyle.none, + desc="Use masked-diffusion for training.", + hint=FieldHint.feature, + ) def _validate(self) -> None: with self._set_implicit_default(): diff --git a/fast_llm/layers/transformer/rotary/config.py b/fast_llm/layers/transformer/rotary/config.py index ce7af88d..8f2c9ab8 100644 --- a/fast_llm/layers/transformer/rotary/config.py +++ b/fast_llm/layers/transformer/rotary/config.py @@ -112,7 +112,7 @@ class YarnRotaryConfig(DefaultRotaryConfig): # TODO: Add descriptions. 
scale_factor: float = Field(default=8.0, hint=FieldHint.feature) - attention_factor: None | float = Field( + attention_factor: float | None = Field( default=None, hint=FieldHint.feature, ) @@ -127,9 +127,9 @@ class YarnRotaryConfig(DefaultRotaryConfig): original_context_length: int = Field(default=8192, hint=FieldHint.feature) def _validate(self) -> None: - if self.attention_factor is None: - with self._set_implicit_default(): - self.attention_factor = 0.1 * math.log(self.scale_factor) + 1.0 + # if self.attention_factor is None: + # # with self._set_implicit_default(): + # self.attention_factor = 0.1 * math.log(self.scale_factor) + 1.0 super()._validate() def _get_configurable_class(self) -> "type[YarnRotary]": diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index 056b9aa4..867e1e33 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/transformer/rotary/rotary.py @@ -181,7 +181,10 @@ class YarnRotary[ConfigType: YarnRotaryConfig](DefaultRotary[YarnRotaryConfig]): """ def _get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda") -> torch.Tensor: - return super()._get_frequencies(sequence_length, kv_channels, device) * self._config.attention_factor + attention_factor = self._config.attention_factor + if attention_factor is None: + attention_factor = 0.1 * math.log(self._config.scale_factor) + 1.0 + return super()._get_frequencies(sequence_length, kv_channels, device) * attention_factor def _get_angle_scales(self, kv_channels: int, device="cuda") -> torch.Tensor: scales = super()._get_angle_scales(kv_channels, device) diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index c206ef40..8b1f1635 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -59,10 +59,9 @@ def preprocess( # TODO: Adjust or reimplement. return super().preprocess(batch, preprocessed_meta, phase=phase, iteration=iteration, metrics=metrics) - @property - def loss_defs(self) -> list[LossDef]: + def get_loss_defs(self) -> list[LossDef]: # TODO: Adjust or reimplement. 
- return super().loss_defs + return super().get_loss_defs() class CustomModel[ConfigType: CustomBaseModelConfig](GPTModel[ConfigType]): diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index d8425786..dbd74e37 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -404,6 +404,7 @@ def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing elif type(rotary_config) is YarnRotaryConfig: rotary_scaling = { "rope_type": "yarn", + "factor": rotary_config.scale_factor, "attention_factor": rotary_config.attention_factor, "beta_fast": rotary_config.beta_fast, "beta_slow": rotary_config.beta_slow, @@ -435,6 +436,7 @@ def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.A elif rotary_type == "yarn": rotary_config.update( { + "scale_factor": rope_scaling.get("factor", DEFAULT), "attention_factor": rope_scaling.get("attention_factor", DEFAULT), "beta_fast": rope_scaling.get("beta_fast", DEFAULT), "beta_slow": rope_scaling.get("beta_slow", DEFAULT), diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 00b4ee27..299ec8d9 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -3,6 +3,7 @@ import torch +from fast_llm.config import DiffusionStyle from fast_llm.data.data.gpt.data import GPTBatch from fast_llm.engine.base_model.base_model import BaseModel, Layer, LossDef from fast_llm.engine.base_model.config import Preprocessor @@ -12,7 +13,7 @@ from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding -from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead +from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead, MLMHead from fast_llm.layers.transformer.config import ( RoutingType, @@ -55,13 +56,15 @@ def __init__( # We have multiple identical rotary modules/preprocessors, so it's simpler to make a new one here. # TODO: Find a better solution. self._preprocessors.append(self._config.transformer.rotary.build(self._tensor_space)) - if self._use_flash_attention: - self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space)) - else: - self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) - if self._config.enable_dpo: # TODO better way to pass in? - self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._tensor_space)) + if self._config.transformer.diffusion == DiffusionStyle.none: + if self._use_flash_attention: + self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space)) + else: + self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) + + if self._config.enable_dpo: # TODO better way to pass in?
+ self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._tensor_space)) def get_output_layers(self) -> list[Layer]: layers = [] @@ -78,13 +81,22 @@ def get_output_layers(self) -> list[Layer]: return_input=i < self._config.prediction_heads - 1, ) ) - layers.append( - LanguageModelHead( - self._config, - self._tensor_space, - prediction_distance=i, + if self._config.transformer.diffusion != DiffusionStyle.none: + layers.append( + MLMHead( + self._config, + self._tensor_space, + prediction_distance=i, + ) + ) + else: + layers.append( + LanguageModelHead( + self._config, + self._tensor_space, + prediction_distance=i, + ) ) - ) return layers def get_layers(self) -> list[Layer]: @@ -323,6 +335,155 @@ def preprocess( kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) kwargs[LanguageModelKwargs.labels] = labels + + if self._config.transformer.diffusion != DiffusionStyle.none: + # if batch.mask_indexes is not None: + assert batch.loss_weights is not None, "masked-diffusion mode needs to set loss_weights" + if self._config.transformer.diffusion == DiffusionStyle.masked: + # assert batch.loss_weights is not None, "masked-diffusion mode needs to set loss_weights" + + # We are in masked-diffusion mode, so we need to add the mask indexes and probabilities to kwargs + # kwargs[LanguageModelKwargs.mask_indexes] = batch.mask_indexes.to( + # device=self._tensor_space.distributed.device + # ) + # kwargs[LanguageModelKwargs.mask_probabilities] = batch.mask_probabilities.to( + # device=self._tensor_space.distributed.device + # ) + # Set up bidirectional attention for masked diffusion + # It uses _flash_attn_func so no need to set attention_mask and attention_mask_value. + kwargs[TransformerKwargs.causal] = False + kwargs[LanguageModelKwargs.loss_weights] = batch.loss_weights.to( + device=self._tensor_space.distributed.device, + dtype=self._tensor_space.distributed_config.training_dtype.torch, + ) + + batch_size, seq_len = batch.token_ids.shape + seq_len -= 1 # the last input token is dropped + # seq_len = kwargs[TransformerKwargs.sequence_length] # alternatively we can use this + # attention_mask = torch.ones( + # (batch_size, 1, seq_len, seq_len), + # dtype=torch.bool, + # device=self._tensor_space.distributed.device, + # ) + # kwargs[TransformerKwargs.attention_mask] = attention_mask.unsqueeze(1).unsqueeze(1) + attention_mask = torch.ones( + (seq_len, seq_len), + dtype=torch.bool, + device=self._tensor_space.distributed.device, + ) + kwargs[TransformerKwargs.attention_mask] = attention_mask[ + None, None, 0:seq_len, None, :seq_len + ] + # alternatively we can use this + # sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size + # sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size + # kwargs[TransformerKwargs.attention_mask] = self._mask[ + # None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k + # ] + # print(f"attention_mask: {kwargs[TransformerKwargs.attention_mask]}") + # # kwargs[TransformerKwargs.attention_mask_value] = torch.tensor( + # # -10000.0, device=self._tensor_space.distributed.device + # # ) + kwargs[TransformerKwargs.attention_mask_value] = torch.full( + [], + torch.finfo(self._tensor_space.distributed_config.training_dtype.torch).min, + dtype=self._tensor_space.distributed_config.training_dtype.torch, + device=self._tensor_space.distributed.device, + ) + # print(f"attention_mask : {attention_mask}") + # print(f"labels shape: {labels}\ntokens: {batch.token_ids}\nmask indexes shape: {batch.mask_indexes}\nmask input: {batch.masked_token_ids}") +
+ # set token ids to masked tokens + batch.token_ids = batch.masked_token_ids.to( + device=self._tensor_space.distributed.device, + dtype=torch.int64, + non_blocking=True, + ) + tokens = batch.token_ids + + elif self._config.transformer.diffusion == DiffusionStyle.ar_masked: + + # We are in masked-diffusion mode, so we need to add the mask indexes and probabilities to kwargs + kwargs[LanguageModelKwargs.mask_indexes] = batch.mask_indexes.to( + device=self._tensor_space.distributed.device + ) + + kwargs[LanguageModelKwargs.loss_weights] = batch.loss_weights.to( + device=self._tensor_space.distributed.device + ) + + kwargs[LanguageModelKwargs.in_context] = batch.in_context.to( + device=self._tensor_space.distributed.device + ) + + # Set up bidirectional attention for diffusion. Should we set this in a preprocessor (BackupAttentionPreprocessor)? + batch_size, seq_len = batch.token_ids.shape + seq_len -= 1 # the last input token is dropped + # Compute attention mask for diffusion + C = batch.in_context_length.to(device=self._tensor_space.distributed.device) + row_idx = torch.arange(seq_len, device=self._tensor_space.distributed.device).view( + 1, seq_len, 1 + ) + col_idx = torch.arange(seq_len, device=self._tensor_space.distributed.device).view( + 1, 1, seq_len + ) + C_exp = C.view(batch_size, 1, 1) + + causal_mask = col_idx <= row_idx + + attn_mask = torch.zeros( + batch_size, + seq_len, + seq_len, + dtype=torch.bool, + device=self._tensor_space.distributed.device, + ) + + for b in range(batch_size): + C_val = C[b].item() + + if C_val > 0: + context_causal = causal_mask[0, :C_val, :C_val] + attn_mask[b, :C_val, :C_val] = context_causal + + if C_val > 0 and C_val < seq_len: + attn_mask[b, C_val:, :C_val] = True + + if C_val < seq_len: + attn_mask[b, C_val:, C_val:] = True + + # Handle padding if needed + if batch.sequence_lengths is not None: + padded = torch.zeros( + batch_size, seq_len, dtype=torch.bool, device=self._tensor_space.distributed.device + ) + for b in range(batch_size): + padded[b, batch.sequence_lengths[b] :] = True + not_padded = ~padded[:, 1:] + attn_mask = attn_mask & not_padded.unsqueeze(1) & not_padded.unsqueeze(2) + + # Reshape to match expected attention mask format + attention_mask = attn_mask.unsqueeze(1).unsqueeze(1) # Add additional dimension + # print(f"attention_mask shape: {attention_mask.shape}\n{attention_mask}") + kwargs[TransformerKwargs.attention_mask] = attention_mask + kwargs[TransformerKwargs.attention_mask_value] = torch.full( + [], + torch.finfo(self._tensor_space.distributed_config.training_dtype.torch).min, + dtype=self._tensor_space.distributed_config.training_dtype.torch, + device=self._tensor_space.distributed.device, + ) + batch.token_ids = batch.masked_token_ids + # print(f"C: {C}") + # print(f"masked_token_ids: {batch.masked_token_ids}") + # print(f"token_ids: {batch.token_ids}") + # print(f"labels: {labels}") + # print(f"loss_weights: {batch.loss_weights}") + # print(f"mask indexes: {batch.mask_indexes}") + # print(f"in_context: {batch.in_context}") + # print(f"attention_mask: {attention_mask}") + kwargs.update(reference_logits[i]) for preprocessor in self._preprocessors: @@ -365,8 +526,7 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: else: return {} - @property - def loss_defs(self) -> list[LossDef]: + def get_loss_defs(self) -> list[LossDef]: loss_defs = [] if ( self._config.transformer.num_experts > 1 @@ -390,6 +550,10 @@ def loss_defs(self) -> list[LossDef]: if
self._config.logit_z_loss: LossDef(name=LanguageModelLossNames.z_loss, formatted_name="logit z loss", count=1) + if self._config.transformer.diffusion != DiffusionStyle.none: + # Masked LM Loss for masked-diffusion training + loss_defs.append(LossDef(name=LanguageModelLossNames.mlm_loss, formatted_name="MLM Loss", count=1)) + for i in range(self._config.prediction_heads): loss_defs.append( LossDef( diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 54508e8e..9003b723 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -31,6 +31,7 @@ def _get_sampling_parameters( "cross_document_attention": self._config.batch.cross_document_attention, "truncate_documents": self._config.batch.truncate_documents, "extra_tokens": self._config.model.base_model.prediction_heads, + "diffusion": self._config.data.sampling.diffusion, } ) return parameters if _return_dict else GPTSamplingParameters(**parameters)
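For reviewers, a minimal sketch of how the new masking path can be exercised in isolation, assuming the prepare_batch helper added to fast_llm/data/data/gpt/data.py in this diff; the batch size, sequence length and vocabulary size below are illustrative toy values, not values used by the trainer.

import torch

from fast_llm.data.data.gpt.data import prepare_batch  # helper introduced in this diff

torch.manual_seed(0)
batch_size, seq_len = 2, 9  # seq_len includes the extra label token, as in gpt_data_collate_fn
token_ids = torch.randint(0, 100, (batch_size, seq_len))
positions = torch.arange(seq_len - 1).unsqueeze(0).expand(batch_size, seq_len - 1)
padded = torch.zeros_like(token_ids, dtype=torch.bool)

# Per-sequence masking rate in (epsilon, 1 - epsilon), mirroring the collate function.
epsilon = 1e-3
p_mask = (1 - 2 * epsilon) * torch.rand(batch_size) + epsilon

out = prepare_batch(
    data_ids=token_ids,
    positions=positions,
    padded=padded,
    mask_token_id=103,  # matches the DiffusionMaskingConfig default
    vocab_size=100,
    context_length=-torch.ones(batch_size, dtype=torch.int),  # masked diffusion: no clean prefix
    p_mask=p_mask,
    ar_factor=0.0,
    un_factor=0.0,
    last_factor=0.0,
)

# Model inputs with masked positions replaced by the mask token, plus per-token loss weights
# (masked positions are up-weighted by 1 / p_mask; all other positions get weight 0 here).
print(out["input_ids"].shape, out["loss_weights"].shape)  # torch.Size([2, 8]) torch.Size([2, 8])
print(out["in_mask"][0])  # which of the 8 input positions were replaced by the mask token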