Skip to content

fix loss masking #345

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft

fix loss masking #345

wants to merge 2 commits into from

Conversation

RaymondLi0
Copy link
Contributor

✨ Description

  • Fix the triton implementation triton_cross_entropy_from_distribution_forward_backward_kernel

Closes #344

πŸ” Type of change

Select all that apply:

  • πŸ› Bug fix (non-breaking change that addresses a specific issue)
  • πŸš€ New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • πŸ“ˆ Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • πŸ› οΈ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • πŸ“¦ Dependency bump (updates dependencies, including Dockerfile or package changes)
  • πŸ“ Documentation change (updates documentation, including new content or typo fixes)
  • πŸ”§ Infrastructure/Build change (affects build process, CI/CD, or dependencies)

πŸ“ Changes

List the key changes introduced in this PR:

  1. Change A
  2. Change B

βœ… Checklist

Make sure the following tasks are completed before submitting the PR:

General

  • πŸ“œ I have read and followed the contributing guidelines.
  • 🏷️ I am using a clear and descriptive PR title that summarizes the key change or feature introduced.
  • πŸŽ‰ The functionality is complete, and I have tested the changes.
  • πŸ“ I have updated the documentation if needed.
  • ⚠️ The change does not introduce any new issues (e.g., runtime warnings, type checker errors, linting problems, unhandled edge cases).
  • 🧩 I have commented my code, especially in hard-to-understand areas.

Dependencies and Configuration

  • πŸ‹ I have updated the Docker configuration or dependencies, if applicable.
  • πŸ”„ I have ensured compatibility with the existing setup after dependency changes.

Testing

  • πŸ§ͺ I have added or updated tests to cover my changes.
  • βœ”οΈ New and existing tests pass locally with my changes.
  • 🚦 I have tested these changes on GPUs and verified training stability.
  • πŸ‹οΈ I have tested the changes on realistic training workloads, if applicable.

Performance Impact

  • πŸ“Š I have run benchmarks where applicable to evaluate the performance impact.
  • βœ… The benchmarks show no performance regression.
  • πŸš€ The benchmarks indicate a potential performance improvement.
  • ⚠️ The benchmarks indicate a potential performance degradation.
  • πŸ“ˆ I have provided benchmark results and detailed any performance impact below, if applicable.

πŸ“Š Performance Impact Details

If there is any impact on performance, describe it and provide benchmark results, if applicable:


πŸ—’οΈ Additional Notes

Include any additional context, information, or considerations here, such as known issues, follow-up tasks, or backward compatibility concerns.

per_token_loss = torch.nn.functional.cross_entropy(
logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target, reduction="none"
)
loss = (per_token_loss * loss_mask).sum() / loss_mask.sum()
Copy link
Contributor

@oleksost oleksost Aug 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can result in nans if loss_mask.sum() is 0, which can happen actually in practice in the context of reasoning SFT where prompts can be very long or when we to TP and split across sequence length dimension

So maybe better to check something like:


            if mask_sum > 0:  # can happen for inputs containing only prompts?
                loss = (loss_per_token * loss_mask).sum() / mask_sum
            else:
                loss = (loss_per_token * 0.0).mean()  # preserve grads

@RaymondLi0
Copy link
Contributor Author

As discussed with @oleksost , to finish the fix we'd also need to properly reduce the loss across micro-sequences, taking into account the sum of the loss_mask.

Now, on second thought:
With the current implementation in main: the contributions of tokens to the gradient will be the same, no matter the number of mask tokens in the sample.
Whereas if we finish this fix and go forward with it: tokens from a sample (sequence) with a lot of masked positions would contribute more to the gradient (compared to tokens from a sample without loss mask).

The question is whether we want an average of the loss over samples, or over tokens.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Fix loss-masking for distillation?
2 participants