Avoid materializing the entire logit matrix for logp calculations. #2772
Open: zkpranav wants to merge 10 commits into unslothai:main from zkpranav:cce-based-logp-calc.
Commits:
0629f22  zkpranav  Modify _get_per_token_logps to use CCE, remove calc_logprob_flag and…
a89b99f  zkpranav  Use UNSLOTH_USE_NEW_MODEL as a check to decide between grpo_accumulat…
a8207c8  zkpranav  Integrate UnslothEfficientGRPO changes
1389f22  zkpranav  Removes UNSLOTH_RETURN_HIDDEN_STATES invalid access
f27c920  zkpranav  Add back sleep_mode changes
5228cab  zkpranav  Removes grad filtering, use batch chunking strategy similar to grpo_a…
a5c0bd2  zkpranav  Resolve conflicts
9211e01  zkpranav  Perform logit scaling with temperature
d467c05  zkpranav  Assume a default temperature of 1.0
75f36ba  zkpranav  No default temperature _get_per_token_logps, corrects check
Um, why do we need cross entropy in get_per_token_logps?
Apparently, these will return logprobs and are equivalent to selective softmax? But I am not sure if we want to return logprobs in this matrix because, like @danielhanchen said, we folded it into a torch.compile kernel. I am questioning whether the memory saved here actually comes from the cut cross entropy loss rather than from the chunked concatenation of the hidden states. I am currently at work, but we can check later whether chunking the hidden states conserves a similar amount of memory.
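For context, a minimal sketch in plain PyTorch (illustrative shapes, not this PR's code) of why the negated per-token cross-entropy is the same quantity as the selective-softmax log-probability of the chosen token:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 1000)          # (tokens, vocab), illustrative sizes
ids = torch.randint(0, 1000, (4,))     # chosen token ids

# "Selective softmax": log-softmax over the vocab, then pick the chosen token.
logp_selective = logits.log_softmax(dim=-1).gather(-1, ids.unsqueeze(-1)).squeeze(-1)
# Negated per-token cross-entropy gives the same values.
logp_from_ce = -F.cross_entropy(logits, ids, reduction="none")

assert torch.allclose(logp_selective, logp_from_ce, atol=1e-5)
```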
Another thing I see is that, according to this person's post, there also seems to be some speed-up as well. Instead of materializing the logits outside of here, we could put linear_cut_cross_entropy in place of the code in selective_softmax, so we get both the speed-up and the memory savings and do not materialize logits outside of the kernel.
@pluesclues That would work too. The only reason I did it this way is to ensure consistency with HF. That being said, we may, at some point, need to write a custom kernel anyway to run fused operations on the logit matrix chunk. Currently, the implementation in HF scales the logits with temperature before computing logps (https://github.com/huggingface/trl/blob/4c92de00001379ceedaf073512ce4df5da304d08/trl/trainer/grpo_trainer.py#L871).
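As a side note on the temperature handling, a hedged sketch with illustrative names and shapes (not this PR's implementation): because the logits are hidden @ lm_head.T, dividing the hidden states by the temperature yields the same per-token logps as dividing the logits, so the scaling can happen before a fused CCE-style call without ever materializing the logit matrix.

```python
import torch

hidden = torch.randn(4, 64)            # (tokens, hidden), illustrative sizes
lm_head = torch.randn(1000, 64)        # (vocab, hidden)
ids = torch.randint(0, 1000, (4,))
temperature = 0.9

# Reference path: materialize the logits, then scale by temperature.
ref = ((hidden @ lm_head.T) / temperature).log_softmax(-1).gather(-1, ids.unsqueeze(-1)).squeeze(-1)
# Equivalent path: scale the hidden states instead, which a fused call could use.
alt = ((hidden / temperature) @ lm_head.T).log_softmax(-1).gather(-1, ids.unsqueeze(-1)).squeeze(-1)

assert torch.allclose(ref, alt, atol=1e-3)
```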
Okay, I just tested this method inside of the kernel. It's as I suspected: we cannot use linear_cross_entropy, which is a torch.compile kernel in itself, inside of a torch.compile kernel. I confirmed this by running ref = -1 * linear_cross_entropy(ref_hidden_states_j[:, :-1, :].to(dtype=lm_head.dtype), lm_head, input_ids_j, reduction="none", impl="cce") right before accumulate_chunk outside the kernel, and by also calling the same line inside the kernel; outside the kernel it works just fine, inside it seems to break. I still haven't tested the speed-up on my machine yet, but so far it looks like we can either merge this or just change the way we calculate logprobs to exactly how CCE does it in their kernel.
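For reference, a self-contained, hedged version of that call outside the compiled region, assuming the cut-cross-entropy package and a CUDA device are available; tensor names and shapes are illustrative, and the negated per-token loss is the per-token log-probability:

```python
import torch
from cut_cross_entropy import linear_cross_entropy

hidden_states = torch.randn(2, 128, 512, device="cuda", dtype=torch.bfloat16)  # (B, L, H)
lm_head = torch.randn(32000, 512, device="cuda", dtype=torch.bfloat16)         # (V, H)
input_ids = torch.randint(0, 32000, (2, 127), device="cuda")                   # targets for the shifted positions

# Per-token logps without materializing the (B, L, V) logit matrix.
per_token_logps = -linear_cross_entropy(
    hidden_states[:, :-1, :].to(dtype=lm_head.dtype),
    lm_head,
    input_ids,
    reduction="none",
    impl="cce",
)
```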
About the memory-saving speed-up I reported, I believe the manner in which I profiled it does not provide an accurate account. I am only logging the peak memory allocated throughout a training step, clearing it at the beginning. This approach fails to account for the memory allocated for the old and ref policies, as they are computed and cached outside the new policy update loop, i.e., every _step % num_iterations == 0. I expected much higher memory savings. I would appreciate some help with this.

Moreover, I would like to confirm that UNSLOTH_USE_NEW_MODEL being set to 0 must be interpreted as the pathway to UnslothEfficientGRPO, as is the case in the current implementation.

Also, UNSLOTH_RETURN_HIDDEN_STATES is set to 1 before executing the forward pass in _get_per_token_logps but never reset to its original value, creating an unintended side effect. This is done in a couple of places. Would it not be better to reset it?
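On the reset question, a minimal hedged sketch of how the flag could be saved and restored around the forward pass (the helper name is illustrative, not the PR's actual code):

```python
import os
from contextlib import contextmanager

@contextmanager
def return_hidden_states():
    # Remember the previous value so the flag does not leak into later calls.
    prev = os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES")
    os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
    try:
        yield
    finally:
        if prev is None:
            os.environ.pop("UNSLOTH_RETURN_HIDDEN_STATES", None)
        else:
            os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = prev

# Illustrative usage inside _get_per_token_logps:
# with return_hidden_states():
#     hidden_states = model(input_ids).logits  # returns hidden states under this flag
```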
Do you have the wandb of memory usage over time (as tracked by trl/wandb itself) of the run?
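For reference, the per-step peak-memory reading described above corresponds to something like the following (an illustration using standard PyTorch calls, not the actual logging code in this PR):

```python
import torch

torch.cuda.reset_peak_memory_stats()   # clear the peak counter at the start of the step
# ... run one training step here ...
peak_gib = torch.cuda.max_memory_allocated() / 1024**3
print(f"peak memory allocated during step: {peak_gib:.2f} GiB")
```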
This is a much smaller run with double the batch size. The CCE version completes its 4 training steps in 7 mins, whereas the current implementation OOMs on my machine after 12 mins.
In this case, the amount of memory saved is roughly 25%.
batch_size = 16
unsloth_num_chunks = 4