unsloth/models/rl_replacements.py (112 changes: 46 additions & 66 deletions)
@@ -245,42 +245,51 @@ def _move_model_to_vllm(self, *args, **kwargs): return None
     pass
 RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__move_model_to_vllm)
 
 
 # Edit _get_per_token_logps to handle mixed precision
 def grpo_trainer__get_per_token_logps(function_name, function):
     if function_name != "_get_per_token_logps": return function
 
-    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, calc_logprob_flag = None):
-        if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0' and not calc_logprob_flag:
-            return None # Unsloth efficient GRPO
-        # Otherwise, calculate normally:
+    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, batch_size=None):
+        from cut_cross_entropy import linear_cross_entropy
+        from numpy import searchsorted
+
+        if not batch_size or batch_size == input_ids.shape[0]:
+            # https://github.com/unslothai/unsloth-zoo/blob/4811c67f729b428735185837000449aa3d840d14/unsloth_zoo/rl_replacements.py#L281
+            bsz = input_ids.shape[0]
+            factors = [i for i in range(1, bsz + 1) if bsz % i == 0]
+            n_chunks = self.args.unsloth_num_chunks if self.args.unsloth_num_chunks is not None else bsz
+            n_chunks = factors[min(searchsorted(factors, n_chunks), len(factors)-1)]
+            batch_size = int(bsz / n_chunks)
+
         if not hasattr(self, '_autocast_dtype'):
             self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
             if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
 
+        lm_head = model.get_output_embeddings().weight
+
         os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
         with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
-            # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
-            hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
-            #logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
-            return hidden_states
-            # input_ids = input_ids[:, -logits_to_keep:]
-            # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
-            # See https://github.com/huggingface/trl/issues/2770
-            # logits = logits[:, -logits_to_keep:]
-            # return logits
-            # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
-            # logits = logits / self.temperature
-            # logps = selective_log_softmax(logits, input_ids)
-
-            # row_indices, col_indices = torch.where(logps < -20)
-
-            # # Method 1: Check if tensors have elements
-            # if len(row_indices) > 0 and len(col_indices) > 0:
-            #     breakpoint() # Breakpoint triggered here
-            #     print("Found high values!")
-            # return logps # compute logprobs for the input tokens
+            all_logps = []
+            for i in range(0, input_ids.size(0), batch_size):
+                input_ids_batch = input_ids[i : i + batch_size]
+                attention_mask_batch = attention_mask[i : i + batch_size]
+                # We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded
+                hidden_states = model(input_ids=input_ids_batch, attention_mask=attention_mask_batch, logits_to_keep=logits_to_keep + 1).logits
+
+                # Divide logits by the sampling temperature.
+                # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
+                # (e @ c.T) / temperature == (e / temperature) @ c.T
+                if (self.temperature is not None):
+                    hidden_states = hidden_states / self.temperature
+
+                # Add a dummy input_id at the end. The last logp is excluded.
+                input_ids_batch = torch.cat((input_ids_batch[:, -logits_to_keep:], torch.zeros((batch_size, 1), dtype=input_ids_batch.dtype, device=input_ids_batch.device)), dim=-1)
+                # selective_log_softmax(e @ c.T, index) == -cce(e, c, index, reduction="none")
+                logps = -1 * linear_cross_entropy(hidden_states.to(dtype=lm_head.dtype), lm_head, input_ids_batch, reduction="none", filter_eps=-torch.inf, impl="cce")
+                all_logps.append(logps[:, :-1])
+            pass
+        pass
+
+        return torch.cat(all_logps, dim=0)
+    pass

     function = inspect.getsource(_get_per_token_logps)
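
Reviewer note: a quick sanity check of the two identities the new comments lean on, selective_log_softmax(e @ c.T, index) == -cce(e, c, index, reduction="none") and (e @ c.T) / temperature == (e / temperature) @ c.T. The sketch below is torch-only and all names and shapes are illustrative; F.cross_entropy stands in for the fused linear_cross_entropy kernel, which returns the same per-token values without ever materializing the (B, L, V) logits.

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    B, L, D, V, T = 2, 5, 16, 32, 0.7     # illustrative batch/length/hidden/vocab/temperature
    h   = torch.randn(B, L, D)            # hidden states e
    W   = torch.randn(V, D)               # lm_head weight c
    ids = torch.randint(0, V, (B, L))     # target token ids

    # Temperature folding: dividing the hidden states up front is
    # equivalent to dividing the logits.
    logits = (h / T) @ W.T

    # selective_log_softmax: gather the log-prob of each target token.
    ref = torch.log_softmax(logits, dim=-1).gather(-1, ids.unsqueeze(-1)).squeeze(-1)

    # The same quantity as negated, unreduced cross-entropy.
    cce = -F.cross_entropy(logits.flatten(0, 1), ids.flatten(), reduction="none").view(B, L)

    assert torch.allclose(ref, cce, atol=1e-5)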
@@ -317,42 +326,23 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
         _logits_to_keep = logits_to_keep
 
         per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
+        # ref_per_token_logps is now cached in _buffered_inputs
+        # https://github.com/huggingface/trl/blob/5206c927f6bb161e45114531b0bca8286acfeada/trl/trainer/grpo_trainer.py#L1292
+        ref_per_token_logps = inputs["ref_per_token_logps"]
 
-        # Compute the KL divergence between the model and the reference model
-        # _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves.
-        # https://github.com/huggingface/trl/blob/05bc43e960396581e458195b8388efe6b82cae1f/trl/trainer/grpo_trainer.py#L1328
-        if self.beta != 0.0:
-            with torch.inference_mode(), model.disable_adapter():
-                ref_per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
-        else:
-            ref_per_token_logps = None
         # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
         # x - x.detach() allows for preserving gradients from x
         advantages = inputs["advantages"]
         # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
         # per_token_loss = -(per_token_loss - self.beta * per_token_kl)
         # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
         if "old_per_token_logps" in inputs.keys():
-            old_hidden_states = inputs["old_per_token_logps"]
-        else:
-            old_hidden_states = None
-
-        input_ids = input_ids[:, -logits_to_keep:]
-        if per_token_logps is not None:
-
-            if ref_per_token_logps is not None:
-                ref_per_token_logps = ref_per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
-
-            per_token_logps = per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
-
+            old_per_token_logps = inputs["old_per_token_logps"]
+        else:
+            old_per_token_logps = None
 
         if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '1':
             loss, completion_length, mean_kl = grpo_compute_loss_slow(
-                ref_per_token_logps,
-                per_token_logps,
-                old_hidden_states,
-                input_ids,
-                completion_mask,
-                self.beta,
-                advantages,
+                ref_per_token_logps, per_token_logps, old_per_token_logps, completion_mask, self.beta, advantages,
                 loss_type = self.args.loss_type,
                 epsilon_low = self.epsilon_low,
                 epsilon_high = self.epsilon_high,
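
Reviewer note: the commented-out per_token_kl / per_token_loss lines kept above spell out the plain GRPO objective that grpo_compute_loss_slow and grpo_accumulated_loss compute in fused form. A minimal sketch of that math, illustrative only; the real helpers add loss_type variants and epsilon clipping.

    import torch

    def grpo_per_token_loss(per_token_logps, ref_per_token_logps, advantages,
                            completion_mask, beta):
        # k3 KL estimator from the comment above:
        # exp(ref - p) - (ref - p) - 1, which is >= 0 and unbiased in expectation.
        delta = ref_per_token_logps - per_token_logps
        per_token_kl = torch.exp(delta) - delta - 1
        # exp(p - p.detach()) evaluates to 1 but keeps d/dp alive, so the
        # advantage-weighted term still receives gradients (the x - x.detach() trick).
        ratio = torch.exp(per_token_logps - per_token_logps.detach())
        per_token_loss = -(ratio * advantages.unsqueeze(1) - beta * per_token_kl)
        # Mean over unmasked completion tokens, then over the batch.
        return ((per_token_loss * completion_mask).sum(dim=1)
                / completion_mask.sum(dim=1)).mean()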
@@ -363,12 +353,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
         else:
             if hasattr(self.args, "loss_type"):
                 loss, completion_length, mean_kl = grpo_accumulated_loss(
-                    self,
-                    _input_ids,
-                    logits_to_keep,
-                    completion_mask,
-                    advantages,
-                    old_hidden_states,
+                    self, completion_mask, advantages, ref_per_token_logps, per_token_logps, old_per_token_logps,
                     n_chunks = self.args.unsloth_num_chunks,
                     loss_type = self.args.loss_type,
                     epsilon_low = self.epsilon_low,
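
Reviewer note: n_chunks = self.args.unsloth_num_chunks is the same knob that drives the divisor-snapping logic added to _get_per_token_logps above. A standalone sketch of that selection with a hypothetical helper name; snapping to a divisor of bsz guarantees bsz % batch_size == 0, so the chunked loop covers every row in equal slices.

    from numpy import searchsorted

    def pick_batch_size(bsz, requested_chunks=None):
        # Ascending divisors of bsz; only these yield equal-sized chunks.
        factors = [i for i in range(1, bsz + 1) if bsz % i == 0]
        n_chunks = requested_chunks if requested_chunks is not None else bsz
        # Snap the request up to the nearest divisor, clamped to the largest.
        n_chunks = factors[min(searchsorted(factors, n_chunks), len(factors) - 1)]
        return bsz // n_chunks

    assert pick_batch_size(8, 3) == 2      # 3 snaps to 4 chunks -> rows of 2
    assert pick_batch_size(8, None) == 1   # default: bsz chunks -> one row each
    assert pick_batch_size(6, 6) == 1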
@@ -380,12 +365,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
             else:
                 # to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17
                 loss, completion_length, mean_kl = grpo_accumulated_loss(
-                    self,
-                    _input_ids,
-                    logits_to_keep,
-                    completion_mask,
-                    advantages,
-                    old_hidden_states,
+                    self, completion_mask, advantages, ref_per_token_logps, per_token_logps, old_per_token_logps,
                     n_chunks = self.args.unsloth_num_chunks,
                     temperature = self.args.temperature,
                 )
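
Reviewer note: epsilon_low and epsilon_high parameterize the usual PPO-style clipping of the importance ratio against the cached old_per_token_logps. A sketch of that standard clipped term, not the exact fused implementation inside grpo_accumulated_loss.

    import torch

    def clipped_term(per_token_logps, old_per_token_logps, advantages,
                     epsilon_low, epsilon_high):
        # Importance ratio between the current policy and the one that sampled.
        ratio = torch.exp(per_token_logps - old_per_token_logps)
        clipped = torch.clamp(ratio, 1.0 - epsilon_low, 1.0 + epsilon_high)
        adv = advantages.unsqueeze(1)
        # Pessimistic min keeps updates inside the trust region.
        return -torch.min(ratio * adv, clipped * adv)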