unsloth/models/rl_replacements.py (30 additions, 41 deletions)
@@ -242,40 +242,36 @@ def _move_model_to_vllm(self, *args, **kwargs): return None
pass
RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__move_model_to_vllm)


# Edit _get_per_token_logps to handle mixed precision
def grpo_trainer__get_per_token_logps(function_name, function):
if function_name != "_get_per_token_logps": return function

def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, calc_logprob_flag = None):
if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0' and not calc_logprob_flag:
return None # Unsloth efficient GRPO
# Otherwise, calculate normally:
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, batch_size=None):
from cut_cross_entropy import linear_cross_entropy

if not hasattr(self, '_autocast_dtype'):
self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16

batch_size = batch_size or input_ids.size(0)
lm_head = model.get_output_embeddings().weight

os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
#logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
return hidden_states
# input_ids = input_ids[:, -logits_to_keep:]
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
# See https://github.com/huggingface/trl/issues/2770
# logits = logits[:, -logits_to_keep:]
# return logits
# logps = selective_log_softmax(logits, input_ids)

# row_indices, col_indices = torch.where(logps < -20)

# # Method 1: Check if tensors have elements
# if len(row_indices) > 0 and len(col_indices) > 0:
# breakpoint() # Breakpoint triggered here
# print("Found high values!")
# return logps # compute logprobs for the input tokens
all_logps = []
for i in range(0, input_ids.size(0), batch_size):
input_ids_batch = input_ids[i : i + batch_size]
attention_mask_batch = attention_mask[i : i + batch_size]

# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
hidden_states = model(input_ids=input_ids_batch, attention_mask=attention_mask_batch, logits_to_keep=logits_to_keep + 1).logits
            # Add a dummy input_id at the end; the last logp is excluded.
            input_ids_batch = torch.cat((input_ids_batch[:, -logits_to_keep:], torch.zeros((input_ids_batch.size(0), 1), dtype=input_ids_batch.dtype, device=input_ids_batch.device)), dim=-1)
logps = -1 * linear_cross_entropy(hidden_states.to(dtype=lm_head.dtype), lm_head, input_ids_batch, reduction="none", impl="cce")
Datta0 (Collaborator) commented on Jun 20, 2025:
Um, why do we need cross entropy in get_per_token_logps?

pluesclues (Collaborator) commented on Jun 20, 2025:

Apparently, these return logprobs and are equivalent to selective softmax? But I am not sure we want to return logprobs in this matrix because, like @danielhanchen said, we folded it into a torch.compile kernel. I am also questioning whether the memory saved here actually comes from the cut cross-entropy loss rather than from the chunked concatenation of the hidden states. I am currently at work, but we can check later whether chunking the hidden states conserves a similar amount of memory.
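
For illustration (not part of the PR), here is a minimal sketch of the equivalence in question: the negated, unreduced cross-entropy of a token equals its log-probability, which is what selective_log_softmax gathers; linear_cross_entropy computes the same quantity directly from the hidden states and the lm_head weight without materializing the full (B, L, V) logits. Shapes and names below are illustrative.

```python
# Toy equivalence check: per-token log-probs via log_softmax + gather
# versus the negated, unreduced cross-entropy.
import torch
import torch.nn.functional as F

B, L, V = 2, 5, 11                      # batch, sequence length, vocab (arbitrary)
logits = torch.randn(B, L, V)
labels = torch.randint(0, V, (B, L))

# "Selective log softmax": gather the log-prob of each realized token
logps = torch.log_softmax(logits, dim=-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)

# Negated per-token cross-entropy gives the same numbers
neg_ce = -F.cross_entropy(logits.reshape(-1, V), labels.reshape(-1), reduction="none").reshape(B, L)

assert torch.allclose(logps, neg_ce, atol=1e-5)
```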

Collaborator:

Another thing I see: according to this person's post, there also seems to be some speed-up. Instead of materializing the logits outside of here, we could also put linear_cross_entropy in place of the code in selective_softmax, so we get both the speed-up and the memory savings and do not materialize the logits outside of the kernel.
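
A rough sketch of what that replacement could look like (a hypothetical helper, not the PR's code), assuming the same linear_cross_entropy signature used in the diff above: per-token logps are computed straight from the hidden states and the output-embedding weight, so the logits matrix is never materialized.

```python
import torch
from cut_cross_entropy import linear_cross_entropy  # assumes cut-cross-entropy is installed

def selective_logps_via_cce(hidden_states: torch.Tensor,   # (B, L, H)
                            lm_head_weight: torch.Tensor,  # (V, H) output embedding weight
                            labels: torch.Tensor):         # (B, L)
    # linear_cross_entropy computes the CE of (hidden_states @ lm_head_weight.T) against
    # labels chunk by chunk; negating the unreduced loss recovers per-token log-probs.
    nll = linear_cross_entropy(
        hidden_states.to(lm_head_weight.dtype),  # match classifier dtype
        lm_head_weight,
        labels,
        reduction="none",
        impl="cce",
    )
    return -nll  # (B, L) per-token log-probs
```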

Author:

@pluesclues That would work too. The only reason I did it this way is to ensure consistency with HF. That being said, we may, at some point, need to write a custom kernel anyway to run fused operations on the logit matrix chunk. Currently, the implementation in HF scales the logits with temperature before computing logps (https://github.com/huggingface/trl/blob/4c92de00001379ceedaf073512ce4df5da304d08/trl/trainer/grpo_trainer.py#L871).
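
For reference, a minimal sketch of the ordering being described (illustrative names, not the TRL source): the logits are divided by the sampling temperature before the per-token log-softmax is taken, so any fused kernel would need to fold the same scaling in.

```python
import torch

def per_token_logps_with_temperature(logits, input_ids, temperature=1.0):
    # Scale logits by temperature first, then gather the log-prob of each realized token.
    logits = logits / temperature
    logps = torch.log_softmax(logits, dim=-1)
    return logps.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1)
```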

Collaborator:

Okay, I just tested this method inside of the kernel. It's as I suspected: we cannot use linear_cross_entropy, which is a torch.compile kernel in itself, inside of a torch.compile kernel. I confirmed this by running ref = -1 * linear_cross_entropy(ref_hidden_states_j[:, :-1, :].to(dtype=lm_head.dtype), lm_head, input_ids_j, reduction="none", impl="cce") right before accumulate_chunk outside the kernel, and also calling the same line inside the kernel; outside the kernel it works just fine, but inside it seems to break. I still haven't tested the speed-up on my machine yet, but so far it looks like we can either merge this or change the way we calculate logprobs to exactly how CCE does it in their kernel.
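
Roughly the shape of the experiment described above (illustrative names; whether the nested call works will depend on the torch and cut-cross-entropy versions): the same CCE call is made once in eager mode and once from inside a torch.compile'd function.

```python
import torch
from cut_cross_entropy import linear_cross_entropy

def logps_eager(hidden, lm_head, ids):
    # Plain eager call -- reported to work fine outside the compiled kernel.
    return -linear_cross_entropy(hidden.to(lm_head.dtype), lm_head, ids,
                                 reduction="none", impl="cce")

@torch.compile
def logps_nested(hidden, lm_head, ids):
    # Same call nested inside a torch.compile region -- the configuration
    # reported to break, since linear_cross_entropy is itself compiled.
    return -linear_cross_entropy(hidden.to(lm_head.dtype), lm_head, ids,
                                 reduction="none", impl="cce")
```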

zkpranav (Author) commented on Jun 20, 2025:

About the memory-saving speed-up I reported: I believe the manner in which I profiled it does not provide an accurate account. I am only logging the peak memory allocated throughout a training step, clearing it at the beginning. This approach fails to account for the memory allocated for the old and ref policies, as they are computed and cached outside the new-policy update loop, i.e., only when _step % num_iterations == 0. I expected much higher memory savings. I would appreciate some help with this.
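
For what it's worth, a sketch of one way to make the measurement cover the cached old/ref logps as well (a sketch only, assuming CUDA and illustrative placement): reset the peak counter once per generation cycle rather than at the top of every optimizer step.

```python
import torch

# Reset once per *generation cycle* (when old/ref logps are computed and cached),
# not at the start of every optimizer step, so those allocations are counted too.
torch.cuda.reset_peak_memory_stats()
# ... run one full cycle: generation, old/ref logp caching, num_iterations policy updates ...
peak_gib = torch.cuda.max_memory_allocated() / 1024**3
print(f"peak GPU memory over the cycle: {peak_gib:.2f} GiB")
```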

Moreover, I would like to confirm that UNSLOTH_USE_NEW_MODEL being set to 0 should be interpreted as the pathway to UnslothEfficientGRPO, as is the case in the current implementation.
Also, UNSLOTH_RETURN_HIDDEN_STATES is set to 1 before executing the forward pass in _get_per_token_logps but is never reset to its original value, creating an unintended side effect. This is done in a couple of places. Would it not be better to reset it?
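
One way to avoid the lingering side effect (a sketch, not the PR's code): stash the previous value and restore it in a finally block around the forward pass.

```python
import os

def forward_with_hidden_states(model, **kwargs):
    # Temporarily request hidden states, then restore whatever was set before.
    previous = os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES")
    os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
    try:
        return model(**kwargs).logits  # hidden states come back via .logits here
    finally:
        if previous is None:
            os.environ.pop("UNSLOTH_RETURN_HIDDEN_STATES", None)
        else:
            os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = previous
```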

Collaborator:

Do you have the wandb plot of memory usage over time (as tracked by trl/wandb itself) for the run?

Author:

[Image: gpu_mem]

Author:

This is a much smaller run with double the batch size. The CCE version completes its 4 training steps in 7 mins, whereas the current implementation OOMs on my machine after 12 mins.
In this case, the amount of memory saved is roughly 25%.

batch_size = 16
unsloth_num_chunks = 4

[Image: gpu_mem_oom]

all_logps.append(logps[:, :-1])
pass
pass

return torch.cat(all_logps, dim=0)
pass

function = inspect.getsource(_get_per_token_logps)
@@ -312,30 +308,23 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
_logits_to_keep = logits_to_keep

per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

# Compute the KL divergence between the model and the reference model
# _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves.
# https://github.com/huggingface/trl/blob/05bc43e960396581e458195b8388efe6b82cae1f/trl/trainer/grpo_trainer.py#L1328
if self.beta != 0.0:
with torch.inference_mode(), model.disable_adapter():
ref_per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
else:
ref_per_token_logps = None
# per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
# ref_per_token_logps is now cached in _buffered_inputs
# https://github.com/huggingface/trl/blob/5206c927f6bb161e45114531b0bca8286acfeada/trl/trainer/grpo_trainer.py#L1292
ref_per_token_logps = inputs["ref_per_token_logps"]

# x - x.detach() allows for preserving gradients from x
advantages = inputs["advantages"]
# per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
# per_token_loss = -(per_token_loss - self.beta * per_token_kl)
# loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
if "old_per_token_logps" in inputs.keys():
old_hidden_states = inputs["old_per_token_logps"]
old_per_token_logps = inputs["old_per_token_logps"]
else:
old_hidden_states = None
input_ids = input_ids[:, -logits_to_keep:]
if per_token_logps is not None:
old_per_token_logps = None

if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '1':
loss, completion_length, mean_kl = grpo_compute_loss_slow(
ref_per_token_logps, per_token_logps, old_hidden_states, input_ids, completion_mask, self.beta, advantages,
ref_per_token_logps, per_token_logps, old_per_token_logps, completion_mask, self.beta, advantages,
loss_type = self.args.loss_type,
epsilon_low = self.epsilon_low, epsilon_high = self.epsilon_high,
max_completion_length = self.args.max_completion_length,
@@ -344,7 +333,7 @@
else:
if hasattr(self.args, "loss_type"):
loss, completion_length, mean_kl = grpo_accumulated_loss(
self, _input_ids, logits_to_keep, completion_mask, advantages, old_hidden_states,
self, completion_mask, advantages, ref_per_token_logps, per_token_logps, old_per_token_logps,
n_chunks = self.args.unsloth_num_chunks,
loss_type = self.args.loss_type,
epsilon_low = self.epsilon_low, epsilon_high = self.epsilon_high,
@@ -354,7 +343,7 @@
else:
# to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17
loss, completion_length, mean_kl = grpo_accumulated_loss(
self, _input_ids, logits_to_keep, completion_mask, advantages, old_hidden_states,
self, completion_mask, advantages, ref_per_token_logps, per_token_logps, old_per_token_logps,
n_chunks = self.args.unsloth_num_chunks,
)
