Skip to content

Conversation

@gopeshh
Copy link
Collaborator

@gopeshh gopeshh commented Apr 21, 2025

✨ Description

Cleaned up the code a bit:

  1. Added Diffusion config object as we discussed
  2. removed noise schedules for v1
  3. Moved loss calculation to head.py (as I noticed language modelling loss is computed there)
  4. Moved bidirectional attention to preprocessing.py file as it seems like the attention mask is computed there

Of course still a WIP but feel free to leave comments and suggestions

These are changes to address this PR: #208 (comment)

@gopeshh gopeshh requested a review from tscholak April 21, 2025 12:12
Copy link

@PierreAndreNoel PierreAndreNoel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just quick feedback as I am very busy with other things, but please remind me to come back here next week and I'll dig deeper in.


t = torch.rand(batch_size, device=device)

p_mask = (1 - diffusion_config.epsilon) * t + diffusion_config.epsilon

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some questions/thoughts (I am just browsing quickly, and I am not looking at the paper right now):

  • Why is the lower bound epsilon and the upper bound max_mask_prob?
  • My guts tell me you never want the mask probability to be exactly 1, for the same kind of reasons you don't want it to be exactly 0.
  • This approach using torch.min will put a discrete probability for p_mask to be exactly max_mask_prob.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you saying this coz we could have many timesteps with the exact masking level set to max_mask_prob? So are you suggesting some soft clipping instead of a hard upper bound?


masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask

if diffusion_config.pad_prob > 0:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Meta: I currently can't comment about padding; it will have to wait for next week, as I need to re-read the paper better (our own work doesn't do padding).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this to include variable length sequences for 1% of the data?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah exactly!

p_mask = torch.min(p_mask, torch.tensor(diffusion_config.max_mask_prob))
p_mask = p_mask[:, None].expand(-1, seq_len)

masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming True means "masked".

attention_mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool)
else:
# Causal attention
attention_mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool).tril_()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding is that you never want such a triangular causal attention, as this would give a strictly worse model than an autoregressive model.

Suppose that, at inference, tokens are unmasked in the order (4, 2, 3, 0, 1). Token 4 is unmasked first, but this triangular matrix prevents all other tokens from ever "seeing" it.

What is the closest case that makes sense to me would be to permute the rows and columns of the triangular matrix using (4,2,3,0,1), so that token 2 can see token 4, token 3 can see tokens 2 and 4, etc.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, permuted rows and columns makes sense - so we can preserve the order in which it was unmasked. I will update this.
I guess this idea is similar to this paper? https://arxiv.org/abs/1906.08237

kwargs['masked_indices'] = masked_indices
kwargs['p_mask'] = p_mask

if self._config.diffusion.bidirectional_attention:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You may want a string instead of a boolean, as there are many possible attention choices (e.g., blocks) that may come up. Also see the next comment below.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, will change this!

masked_p = p_mask[masked_indices]

# Compute MLM loss
loss, grad = cross_entropy_forward_backward(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jlamypoirier
Copy link
Collaborator

Is this still relevant?

@jlamypoirier jlamypoirier changed the title Changes for basic LLaDA style diffusion masking support [Inactive] Changes for basic LLaDA style diffusion masking support Jun 16, 2025
@gopeshh gopeshh closed this Jun 22, 2025
@gopeshh gopeshh deleted the gopeshh/masked_diffusion branch June 22, 2025 12:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants