
Masked Diffusion Training with Shift #294


Draft · wants to merge 35 commits into base: main
Changes from all commits · 35 commits
db28a11
changes for basic LLaDA style diffusion masking support
gopeshh Apr 21, 2025
3d44671
tests for masking and MLM loss
gopeshh Apr 22, 2025
46dd535
temp fixes
nitsanluke Jun 4, 2025
aa8ab4d
tmp fix
nitsanluke Jun 4, 2025
9f348e7
including masked diffusion training setup
nitsanluke Jun 7, 2025
cdc9c96
adding weighted loss
nitsanluke Jun 11, 2025
d71e693
clean up
nitsanluke Jun 11, 2025
072e6c4
add loss weight
nitsanluke Jun 13, 2025
6127544
adding updates to p_mask
nitsanluke Jun 13, 2025
1cf15a8
update error mgs
nitsanluke Jun 16, 2025
f7a46d7
add comments and clean up
nitsanluke Jun 16, 2025
b80024e
Merge branch 'main' into luke/gopeshh/masked_diffusion
nitsanluke Jun 16, 2025
01a683b
fx merge errors
nitsanluke Jun 18, 2025
ba913e1
fix merge issues
nitsanluke Jun 18, 2025
6c0c72d
register mask config
nitsanluke Jun 18, 2025
26aa13a
Merge branch 'main' into luke/gopeshh/masked_diffusion
nitsanluke Jun 18, 2025
3245496
fx merge issues
nitsanluke Jun 18, 2025
5198310
Merge branch 'main' into luke/gopeshh/masked_diffusion
nitsanluke Jun 23, 2025
4ad0bc1
fix labels
nitsanluke Jun 23, 2025
acacfe3
drop old tests
nitsanluke Jun 23, 2025
2a06ed4
tmp fix
nitsanluke Jun 24, 2025
dd68d28
fx tests
nitsanluke Jun 24, 2025
e0a7c80
update missing rotery export
nitsanluke Jun 25, 2025
0306e36
reset attention_factor to old behaviour
nitsanluke Jun 25, 2025
6bcb38d
setting attention to _flash_attn_func
nitsanluke Jun 27, 2025
093aa33
debug
nitsanluke Jun 28, 2025
141ed88
avg only non-zero loss
nitsanluke Jun 28, 2025
8bb00ed
debug remove
nitsanluke Jun 28, 2025
38737d4
remove non-zero weight
nitsanluke Jul 4, 2025
b043efe
revert to mean loss on all tokens
nitsanluke Jul 4, 2025
0c221fd
tmp
nitsanluke Jul 4, 2025
d29af35
adding fused attn
nitsanluke Jul 4, 2025
aa0d08c
include ar+masking
nitsanluke Jul 6, 2025
014b92e
Merge branch 'main' into luke/gopeshh/masked_diffusion
nitsanluke Jul 6, 2025
632dc7c
main update cr loss
nitsanluke Jul 6, 2025
10 changes: 10 additions & 0 deletions fast_llm/config.py
@@ -1099,3 +1099,13 @@ def pop_nested_dict_value[
return d.pop(keys[-1])
else:
return d.pop(keys)


class DiffusionStyle(str, enum.Enum):
"""
Type of diffusion masking to use.
"""

masked = "masked"
ar_masked = "autoregressive_masked"
none = None
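
For reference, a minimal sketch (not part of this diff) of how the new enum behaves as a config value; the diffusion field that consumes it is added in fast_llm/data/dataset/gpt/config.py further down:

from fast_llm.config import DiffusionStyle

# String values round-trip through the enum, so a config can simply specify
# "masked" or "autoregressive_masked".
assert DiffusionStyle("masked") is DiffusionStyle.masked
assert DiffusionStyle("autoregressive_masked") is DiffusionStyle.ar_masked
assert DiffusionStyle.masked.value == "masked"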
195 changes: 194 additions & 1 deletion fast_llm/data/data/gpt/data.py
@@ -9,6 +9,7 @@
import torch
import torch.utils.data

from fast_llm.config import DiffusionStyle
from fast_llm.core.distributed import safe_barrier
from fast_llm.data.data.abstract import Data
from fast_llm.data.data.gpt.config import GPTDataConfig
@@ -34,12 +35,197 @@ class GPTBatch:
sequence_lengths: list[torch.Tensor] | None = None
chosen_spans: list[torch.Tensor] | None = None
rejected_spans: list[torch.Tensor] | None = None
mask_indexes: torch.Tensor | None = None
mask_probabilities: torch.Tensor | None = None
masked_token_ids: torch.Tensor | None = None
loss_weights: torch.Tensor | None = None
in_context_length: torch.Tensor | None = None
in_context: torch.Tensor | None = None


def do_mask(x, mask, mask_token_id):
x = x.clone()
x[mask] = mask_token_id
return x


def do_uniform(x, is_uniform, vocab_size):
# WARNING! "Shuffle" was really meant to mean "uniformly sample among all non-mask tokens"
x = x.clone()
uniform = torch.randint(0, vocab_size, x.size())
x[is_uniform] = uniform[is_uniform]
return x


def prepare_batch(
data_ids,
positions,
padded,
mask_token_id,
vocab_size,
context_length,
p_mask,
*,
p_uniform=0.0,
ar_factor=1.0,
un_factor=1.0,
last_factor=0.0,
in_mask=None,
in_uniform=None,
):

B, L = positions.size()
context_length = context_length.unsqueeze(1).expand(B, L)
p_mask = p_mask.unsqueeze(1)

# Reminder: a context_length of zero still has one in_context token (à la <BOS>)
in_context = positions <= context_length
if in_mask is None:
in_mask = (~in_context) & (torch.rand(B, L) < p_mask)

if in_uniform is None:
in_uniform = (~in_context) & (~in_mask) & (torch.rand(B, L) < p_uniform)
in_clean = (~in_context) & (~in_mask) & (~in_uniform)

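# Per-token loss weights, aligned with the left-shifted labels and zeroed on padded positions:
# - context (AR prefix) tokens are weighted by ar_factor,
# - masked tokens are importance-weighted by 1 / p_mask,
# - uniform-noise and clean tokens share un_factor / (1 - p_mask), split by (1 - p_uniform) and p_uniform,
# - the extra final slot appended below is weighted by last_factor (0 by default, i.e. ignored).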
loss_weights = (~padded)[:, 1:] * torch.cat(
[
ar_factor * in_context[:, 1:]
+ in_mask[:, 1:] / p_mask
+ un_factor * ((1 - p_uniform) * in_uniform[:, 1:] + p_uniform * in_clean[:, 1:]) / (1 - p_mask),
last_factor * torch.ones(B, 1),
],
dim=1,
)

input_ids = do_uniform(do_mask(data_ids[:, :-1], in_mask, mask_token_id), in_uniform, vocab_size)

# print(
# f"{'Name':<20} {'Shape/Value':<30}\n"
# f"{'-'*50}\n"
# f"{'input_ids':<20} {str(input_ids.shape):<30}\n"
# f"{'in_context':<20} {str(in_context.shape):<30}\n"
# f"{'in_mask':<20} {str(in_mask.shape):<30}\n"
# f"{'in_uniform':<20} {str(in_uniform.shape):<30}\n"
# f"{'in_clean':<20} {str(in_clean.shape):<30}\n"
# f"{'loss_weights':<20} {str(loss_weights.shape):<30}\n"
# f"{'in_context_length':<20} {str(context_length):<30}\n"
# )

return {
"in_context": in_context, # Only for tokens to be predicted
"in_mask": in_mask,
"in_uniform": in_uniform,
"in_clean": in_clean,
"input_ids": input_ids,
# "target_ids": data_ids,
"loss_weights": loss_weights,
# "in_context_length": context_length,
}


def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch:

stacked_ids = np.stack([sample.token_ids for sample in batch])
stacked_spans = None
sequence_lengths = None
mask_indexes = None
mask_probabilities = None
masked_token_ids = None

loss_weights = None
in_context_length = None
in_context = None

token_ids = torch.from_numpy(stacked_ids)

if sampling_parameters.diffusion.style == DiffusionStyle.masked:

diffusion_config = sampling_parameters.diffusion

batch_size, seq_len = token_ids.shape
mask_token_id = diffusion_config.mask_token_id

# Generate a random tensor of batch size to seed masking probabilities
t = torch.rand((batch_size,))

# Compute the mask probabilities for every sequence in the batch
p_mask = (1 - (2 * diffusion_config.epsilon)) * t + diffusion_config.epsilon
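# This maps t in [0, 1) affinely onto [epsilon, 1 - epsilon), so no sequence is fully unmasked or fully masked.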

# Do we need to clamp at max_mask_prob?
# p_mask = torch.min(p_mask, torch.tensor(diffusion_config.max_mask_prob))

# The input carries one extra token for the shift: with x = [A, B, C, D, E, F],
#   inputs (embedded) = [A, B, C, D, E]
#   labels            = [B, C, D, E, F]
# i.e. the labels are the left-shifted inputs, and the last input token is
# dropped from processing.

# Generate random values for all tokens in the batch and mask only the positions
# where the value is smaller than the mask probability.
mask_indexes = torch.rand((batch_size, seq_len)) < p_mask[:, None]

# Needs further clarification: padding scheme where ~1% of the data has partial sequences and padding
# if diffusion_config.pad_prob > 0:
# pad_mask = torch.rand((batch_size,), device=device) < diffusion_config.pad_prob
# if pad_mask.any():
# mask_indexes[pad_mask] = True

# Replace masked tokens with the mask token ID to create input for the model.
masked_token_ids = torch.where(mask_indexes, mask_token_id, token_ids)

mask_indexes = mask_indexes[:, :-1] # Remove the last token, which is not used for prediction.
mask_probabilities = p_mask

elif sampling_parameters.diffusion.style == DiffusionStyle.ar_masked:
diffusion_config = sampling_parameters.diffusion
batch_size, seq_len = token_ids.shape
data_ids = token_ids
padded = torch.zeros_like(data_ids, dtype=torch.bool)
positions = torch.arange(seq_len - 1).unsqueeze(0).expand(batch_size, seq_len - 1)

# TODO:
# 90% of the batch: C = random [0, seq_len // 4], 10%: C = random in [0, seq_len-2)
prob = torch.rand(1)
C = torch.where(
prob > diffusion_config.context_sampler,
torch.randint(0, seq_len // 4, (batch_size,), dtype=torch.long),
torch.randint(0, seq_len - 2, (batch_size,), dtype=torch.long),
)
# C = torch.randint(0, (seq_len - 2), (batch_size,), dtype=torch.long)
# C = -torch.ones(batch_size, dtype=torch.int)
# Generate a random tensor of batch size to seed masking probabilities
t = torch.rand((batch_size,))
# Compute the mask probabilities for every sequence in the batch, leaving out the extremes 0 and 1
p_mask = (1 - (2 * diffusion_config.epsilon)) * t + diffusion_config.epsilon

batch_data = prepare_batch(
data_ids=data_ids,
positions=positions,
padded=padded,
mask_token_id=diffusion_config.mask_token_id,
vocab_size=sampling_parameters.vocab_size,
context_length=C,
p_mask=p_mask,
p_uniform=0.0, # no uniform shuffling of tokens
ar_factor=diffusion_config.ar_factor,
un_factor=1.0,
last_factor=0.0,
)

# token_ids = batch_data["input_ids"]
masked_token_ids = batch_data["input_ids"]

mask_indexes = batch_data["in_mask"]
# mask_probabilities = torch.full_like(mask_indexes, diffusion_config.max_mask_prob, dtype=token_ids.dtype)
loss_weights = batch_data["loss_weights"]
in_context_length = C
in_context = batch_data["in_context"]

if sampling_parameters.use_loss_masking_spans:
stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch]

stacked_chosen_spans = None
stacked_rejected_spans = None
if sampling_parameters.use_loss_masking_spans:
@@ -49,12 +235,19 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling
stacked_rejected_spans = [torch.from_numpy(sample.rejected_span) for sample in batch]
if not sampling_parameters.cross_document_attention:
sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch]

return GPTBatch(
token_ids=torch.from_numpy(stacked_ids),
token_ids=token_ids,
loss_masking_spans=stacked_spans,
sequence_lengths=sequence_lengths,
chosen_spans=stacked_chosen_spans,
rejected_spans=stacked_rejected_spans,
mask_indexes=mask_indexes,
mask_probabilities=mask_probabilities,
masked_token_ids=masked_token_ids,
loss_weights=loss_weights,
in_context_length=in_context_length,
in_context=in_context,
)


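For reference, a condensed, self-contained sketch (not part of this diff) of the masked-diffusion branch of gpt_data_collate_fn above, assuming the default mask token id 103 from DiffusionMaskingConfig:

import torch

def llada_style_mask(token_ids: torch.Tensor, mask_token_id: int, epsilon: float = 1e-3):
    """Per-sequence mask rate in [epsilon, 1 - epsilon), independent per-token Bernoulli masking."""
    batch_size, seq_len = token_ids.shape
    t = torch.rand(batch_size)
    p_mask = (1 - 2 * epsilon) * t + epsilon                        # (B,)
    mask_indexes = torch.rand(batch_size, seq_len) < p_mask[:, None]
    masked_token_ids = torch.where(mask_indexes, mask_token_id, token_ids)
    # The last position has no label after the shift, so its mask flag is dropped.
    return masked_token_ids, mask_indexes[:, :-1], p_mask

# Toy usage:
tokens = torch.randint(0, 100, (2, 8))
masked, mask, p = llada_style_mask(tokens, mask_token_id=103)
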
53 changes: 52 additions & 1 deletion fast_llm/data/dataset/gpt/config.py
@@ -8,7 +8,16 @@

import yaml

from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none
from fast_llm.config import (
Config,
DiffusionStyle,
Field,
FieldHint,
FieldUpdate,
check_field,
config_class,
skip_valid_if_none,
)
from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset
from fast_llm.data.dataset.config import (
BlendedDatasetConfig,
@@ -44,6 +53,43 @@ class ShufflingType(str, enum.Enum):
legacy = "legacy"


@config_class(registry=True)
class DiffusionMaskingConfig(Config):
"""Configuration for diffusion-based masking during data preparation."""

style: DiffusionStyle = Field(
default=DiffusionStyle.none, desc="Which style of diffusion masking to use during training", hint=FieldHint.feature
)

epsilon: float = Field(
default=1e-3, desc="Minimum masking probability", hint=FieldHint.performance, valid=check_field(Assert.gt, 0)
)

max_mask_prob: float = Field(
default=0.15, desc="Maximum masking probability", hint=FieldHint.performance, valid=check_field(Assert.gt, 0)
)

pad_prob: float = Field(
default=0.01,
desc="Probability of applying padding to a sample (1% of samples by default)",
hint=FieldHint.optional,
valid=check_field(Assert.geq, 0),
)

mask_token_id: int = Field(default=103, desc="Token ID to use for masking", hint=FieldHint.optional)
ar_factor: float = Field(
default=1.0,
desc="Factor for the AR weighting on the overall loss.",
hint=FieldHint.optional,
)
context_sampler: float = Field(
default=1.0, desc="Threshold for sampling the context length C from under 25% of the sequence length vs. the full range", hint=FieldHint.optional
)

def _validate(self) -> None:
super()._validate()


@config_class()
class GPTSamplingConfig(SamplingConfig):
"""
@@ -62,6 +108,10 @@ class GPTSamplingConfig(SamplingConfig):
desc="Shuffling strategy.",
hint=FieldHint.feature,
)
diffusion: DiffusionMaskingConfig = Field(
desc="Configuration for diffusion-based masking during data preparation.",
hint=FieldHint.feature,
)


@dataclasses.dataclass(kw_only=True)
@@ -79,6 +129,7 @@ class GPTSamplingParameters(SamplingParameters):
# How many extra tokens to add to the sequence length.
# This is used to provide labels even for the last tokens in the sequence.
extra_tokens: int = 1
diffusion: DiffusionMaskingConfig


@dataclasses.dataclass(kw_only=True)
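To make ar_factor concrete, a small numeric sketch (not part of this diff) of the per-token weights that prepare_batch (in the data.py diff above) produces in the ar_masked setting, ignoring padding, the label shift, and the last_factor slot; the numbers are illustrative:

import torch

p_mask = 0.25       # per-sequence mask rate drawn as (1 - 2 * epsilon) * t + epsilon
ar_factor = 1.0     # config default: weight of the autoregressive context prefix
in_context = torch.tensor([1., 1., 0., 0., 0., 0.])  # AR prefix positions
in_mask = torch.tensor([0., 0., 1., 0., 1., 0.])     # masked positions
# With p_uniform = 0 (as passed in the ar_masked branch) no tokens are uniform-noised,
# so clean tokens carry zero weight and only context and masked tokens contribute:
weights = ar_factor * in_context + in_mask / p_mask
print(weights)  # tensor([1., 1., 4., 0., 4., 0.])
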
3 changes: 1 addition & 2 deletions fast_llm/engine/base_model/base_model.py
@@ -137,9 +137,8 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]:
# The name (dict key) is used to insert the weight in the kwargs of the forward pass.
return {}

@property
@abc.abstractmethod
def loss_defs(self) -> list[LossDef]:
def get_loss_defs(self) -> list[LossDef]:
pass

def add_reference_model(self, name: str, inference_runner: "InferenceRunner") -> None:
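The three call-site diffs below are the mechanical follow-up to this rename; at each site the property read becomes a method call (multi_stage here stands in for whatever object holds the base model):

# Before: loss_defs was an abstract read-only property.
loss_defs = multi_stage.base_model.loss_defs
# After: get_loss_defs() is an abstract method, so call sites add parentheses.
loss_defs = multi_stage.base_model.get_loss_defs()
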
2 changes: 1 addition & 1 deletion fast_llm/engine/evaluation/evaluator.py
@@ -119,7 +119,7 @@ def setup(
phase=PhaseType.validation,
)

self._loss_defs = self._multi_stage.base_model.loss_defs
self._loss_defs = self._multi_stage.base_model.get_loss_defs()
self._evaluation_iterator = None
self._is_setup = True

2 changes: 1 addition & 1 deletion fast_llm/engine/schedule/runner.py
@@ -93,7 +93,7 @@ def __init__(
self._stages: list[Stage] = self._multi_stage.stages
self._tied_parameters = self._multi_stage.tied_parameters
self._num_stages = len(self._stages)
self._loss_defs = {loss_def.name: loss_def for loss_def in self._multi_stage.base_model.loss_defs}
self._loss_defs = {loss_def.name: loss_def for loss_def in self._multi_stage.base_model.get_loss_defs()}

def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> None:
assert not self._is_setup
2 changes: 1 addition & 1 deletion fast_llm/engine/training/trainer.py
@@ -152,7 +152,7 @@ def __init__(self, config: TrainerConfig):
multi_stage=self._multi_stage,
distributed_config=self._config.model.distributed,
)
self._loss_defs = self._multi_stage.base_model.loss_defs
self._loss_defs = self._multi_stage.base_model.get_loss_defs()

if not self._is_evaluation_only:
steps_per_split = {