diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py
index 18f772056..6b1063c32 100644
--- a/unsloth/models/rl_replacements.py
+++ b/unsloth/models/rl_replacements.py
@@ -245,42 +245,51 @@ def _move_model_to_vllm(self, *args, **kwargs):
         return None
     pass
 RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__move_model_to_vllm)
-
-# Edit _get_per_token_logps to handle mixed precision
 def grpo_trainer__get_per_token_logps(function_name, function):
     if function_name != "_get_per_token_logps": return function
-    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, calc_logprob_flag = None):
-        if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0' and not calc_logprob_flag:
-            return None # Unsloth efficient GRPO
-        # Otherwise, calculate normally:
+    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, batch_size=None):
+        from cut_cross_entropy import linear_cross_entropy
+        from numpy import searchsorted
+
+        if not batch_size or batch_size == input_ids.shape[0]:
+            # https://github.com/unslothai/unsloth-zoo/blob/4811c67f729b428735185837000449aa3d840d14/unsloth_zoo/rl_replacements.py#L281
+            bsz = input_ids.shape[0]
+            factors = [i for i in range(1, bsz + 1) if bsz % i == 0]
+            n_chunks = self.args.unsloth_num_chunks if self.args.unsloth_num_chunks is not None else bsz
+            n_chunks = factors[min(searchsorted(factors, n_chunks), len(factors) - 1)]
+            batch_size = int(bsz / n_chunks)
+
         if not hasattr(self, '_autocast_dtype'):
             self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
             if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
+        lm_head = model.get_output_embeddings().weight
+
         os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
         with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
-            # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
-            hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
-            #logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
-            return hidden_states
-        # input_ids = input_ids[:, -logits_to_keep:]
-        # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
-        # See https://github.com/huggingface/trl/issues/2770
-        # logits = logits[:, -logits_to_keep:]
-        # return logits
-        # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
-        # logits = logits / self.temperature
-        # logps = selective_log_softmax(logits, input_ids)
-
-        # row_indices, col_indices = torch.where(logps < -20)
-
-        # # Method 1: Check if tensors have elements
-        # if len(row_indices) > 0 and len(col_indices) > 0:
-        #     breakpoint() # Breakpoint triggered here
-        #     print("Found high values!")
-        # return logps # compute logprobs for the input tokens
+            all_logps = []
+            for i in range(0, input_ids.size(0), batch_size):
+                input_ids_batch = input_ids[i : i + batch_size]
+                attention_mask_batch = attention_mask[i : i + batch_size]
+                # We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded.
+                hidden_states = model(input_ids=input_ids_batch, attention_mask=attention_mask_batch, logits_to_keep=logits_to_keep + 1).logits
+
+                # Divide logits by the sampling temperature.
+                # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
+                # (e @ c.T) / temperature == (e / temperature) @ c.T
+                if self.temperature is not None:
+                    hidden_states = hidden_states / self.temperature
+
+                # Append a dummy input_id at the end; its logp is excluded below. Use the actual chunk size so a short final chunk is safe.
+                input_ids_batch = torch.cat((input_ids_batch[:, -logits_to_keep:], torch.zeros((input_ids_batch.shape[0], 1), dtype=input_ids_batch.dtype, device=input_ids_batch.device)), dim=-1)
+                # selective_log_softmax(e @ c.T, index) == -cce(e, c, index, reduction="none")
+                logps = -1 * linear_cross_entropy(hidden_states.to(dtype=lm_head.dtype), lm_head, input_ids_batch, reduction="none", filter_eps=-torch.inf, impl="cce")
+                all_logps.append(logps[:, :-1])
+            pass
         pass
+
+        return torch.cat(all_logps, dim=0)
     pass

     function = inspect.getsource(_get_per_token_logps)
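Review note on the chunk-size selection above: `searchsorted` snaps the requested `unsloth_num_chunks` up to the nearest divisor of the batch size, so every micro-batch has the same number of rows and the final `torch.cat` sees uniform shapes. A minimal sketch of just that arithmetic, assuming numpy's default side="left" behaviour (`pick_batch_size` is an illustrative name, not part of the patch):

    from numpy import searchsorted

    def pick_batch_size(bsz, requested_chunks=None):
        # Divisors of bsz, in increasing order.
        factors = [i for i in range(1, bsz + 1) if bsz % i == 0]
        n_chunks = requested_chunks if requested_chunks is not None else bsz
        # First divisor >= the request, clamped to the largest divisor.
        n_chunks = factors[min(searchsorted(factors, n_chunks), len(factors) - 1)]
        return bsz // n_chunks

    assert pick_batch_size(8, 2) == 4  # 2 divides 8, so 4 rows per chunk
    assert pick_batch_size(6, 4) == 1  # 4 is rounded up to 6 chunks of 1 row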
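Review note on the `linear_cross_entropy` call: cut_cross_entropy's fused kernel avoids materializing the full (B, L, V) logit matrix, but per token it computes the same value as an ordinary gathered log-softmax, which is what the `-cce(...)` identity in the comment relies on. A pure-PyTorch sketch of that identity for checking, not of the kernel itself (`per_token_logps_reference` is an illustrative name); it also shows why dividing the hidden states by the temperature equals dividing the logits:

    import torch
    import torch.nn.functional as F

    def per_token_logps_reference(hidden_states, lm_head, input_ids, temperature=1.0):
        # hidden_states: (B, L, H), lm_head: (V, H), input_ids: (B, L)
        logits = (hidden_states / temperature) @ lm_head.T  # == (h @ W.T) / t
        logps = torch.log_softmax(logits, dim=-1)
        return torch.gather(logps, -1, input_ids.unsqueeze(-1)).squeeze(-1)

    B, L, H, V = 2, 5, 16, 32
    h, W = torch.randn(B, L, H), torch.randn(V, H)
    ids = torch.randint(0, V, (B, L))
    # Negated unreduced cross entropy == gathered log-softmax, per token.
    nll = F.cross_entropy((h @ W.T).flatten(0, 1), ids.flatten(), reduction="none")
    assert torch.allclose(per_token_logps_reference(h, W, ids), -nll.view(B, L), atol=1e-5)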
@@ -317,42 +326,23 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
     _logits_to_keep = logits_to_keep
     per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
+    # ref_per_token_logps is now cached in _buffered_inputs
+    # https://github.com/huggingface/trl/blob/5206c927f6bb161e45114531b0bca8286acfeada/trl/trainer/grpo_trainer.py#L1292
+    ref_per_token_logps = inputs["ref_per_token_logps"]
-    # Compute the KL divergence between the model and the reference model
-    # _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves.
-    # https://github.com/huggingface/trl/blob/05bc43e960396581e458195b8388efe6b82cae1f/trl/trainer/grpo_trainer.py#L1328
-    if self.beta != 0.0:
-        with torch.inference_mode(), model.disable_adapter():
-            ref_per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
-    else:
-        ref_per_token_logps = None
-    # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
     # x - x.detach() allows for preserving gradients from x
     advantages = inputs["advantages"]
     # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
     # per_token_loss = -(per_token_loss - self.beta * per_token_kl)
     # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
     if "old_per_token_logps" in inputs.keys():
-        old_hidden_states = inputs["old_per_token_logps"]
-    else:
-        old_hidden_states = None
-
-    input_ids = input_ids[:, -logits_to_keep:]
-    if per_token_logps is not None:
-
-        if ref_per_token_logps is not None:
-            ref_per_token_logps = ref_per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
-
-        per_token_logps = per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
-
+        old_per_token_logps = inputs["old_per_token_logps"]
+    else:
+        old_per_token_logps = None
+
+    if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '1':
         loss, completion_length, mean_kl = grpo_compute_loss_slow(
-            ref_per_token_logps,
-            per_token_logps,
-            old_hidden_states,
-            input_ids,
-            completion_mask,
-            self.beta,
-            advantages,
+            ref_per_token_logps, per_token_logps, old_per_token_logps, completion_mask, self.beta, advantages,
             loss_type = self.args.loss_type,
             epsilon_low = self.epsilon_low,
             epsilon_high = self.epsilon_high,
@@ -363,12 +353,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
     else:
         if hasattr(self.args, "loss_type"):
             loss, completion_length, mean_kl = grpo_accumulated_loss(
-                self,
-                _input_ids,
-                logits_to_keep,
-                completion_mask,
-                advantages,
-                old_hidden_states,
+                self, completion_mask, advantages, ref_per_token_logps, per_token_logps, old_per_token_logps,
                 n_chunks = self.args.unsloth_num_chunks,
                 loss_type = self.args.loss_type,
                 epsilon_low = self.epsilon_low,
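Review note on the caching change: because TRL's `_prepare_inputs` now stores `ref_per_token_logps` in `_buffered_inputs`, the `beta != 0.0` branch no longer needs a second forward pass with the adapter disabled. The KL penalty can be formed directly from the cached tensor using the k3 estimator that the removed comment spelled out; a minimal sketch (`per_token_kl` is an illustrative name, not part of the patch):

    import torch

    def per_token_kl(ref_per_token_logps, per_token_logps):
        # k3 estimator: exp(ref - logp) - (ref - logp) - 1, elementwise >= 0.
        delta = ref_per_token_logps - per_token_logps
        return torch.exp(delta) - delta - 1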
@@ -380,12 +365,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
         else:
             # to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17
             loss, completion_length, mean_kl = grpo_accumulated_loss(
-                self,
-                _input_ids,
-                logits_to_keep,
-                completion_mask,
-                advantages,
-                old_hidden_states,
+                self, completion_mask, advantages, ref_per_token_logps, per_token_logps, old_per_token_logps,
                 n_chunks = self.args.unsloth_num_chunks,
                 temperature = self.args.temperature,
             )
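Review note on the new `grpo_accumulated_loss` signature: the trainer now hands over precomputed per-token log-probs (current, old, reference) instead of `_input_ids` and `logits_to_keep`, so the loss helper no longer re-runs the model. A hedged sketch of the PPO-style clipped objective those tensors feed, following TRL's standard GRPO formulation rather than unsloth's exact internals, which this patch does not show (`grpo_loss_sketch` is an illustrative name):

    import torch

    def grpo_loss_sketch(per_token_logps, old_per_token_logps, ref_per_token_logps,
                         advantages, completion_mask, beta, epsilon_low, epsilon_high):
        # Importance ratio vs. the behaviour policy; fall back to a detached copy.
        old = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
        ratio = torch.exp(per_token_logps - old)  # (B, L)
        adv = advantages.unsqueeze(1)             # (B, 1)
        clipped = torch.clamp(ratio, 1 - epsilon_low, 1 + epsilon_high)
        per_token_loss = -torch.min(ratio * adv, clipped * adv)
        if beta != 0.0 and ref_per_token_logps is not None:
            # k3 KL penalty against the cached reference log-probs.
            delta = ref_per_token_logps - per_token_logps
            per_token_loss = per_token_loss + beta * (torch.exp(delta) - delta - 1)
        # Average over completion tokens only.
        return (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1)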