From 8229a9bc9073d0f0612e2cc65c800c32581b39be Mon Sep 17 00:00:00 2001 From: GAD-cell Date: Mon, 16 Jun 2025 17:06:05 +0200 Subject: [PATCH 01/16] Updated rl and rl_replacements --- unsloth/models/rl.py | 8 +++ unsloth/models/rl_replacements.py | 85 +++++++++++++++++++++++++++++-- 2 files changed, 88 insertions(+), 5 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 3fa3fe713..f533711a5 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -134,6 +134,7 @@ class Unsloth{RLConfig_name}({RLConfig_name}): metadata = {{'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}}, ) def __init__({RLConfig_arguments}, + use_vision = False, vllm_sampling_params = None, unsloth_num_chunks = -1, **kwargs, @@ -142,6 +143,7 @@ def __init__({RLConfig_arguments}, super().__init__({RLConfig_call_args}{RLConfig_kwargs}) self.vllm_sampling_params = vllm_sampling_params self.unsloth_num_chunks = unsloth_num_chunks + self.use_vision = use_vision pass {RLTrainer_extras} @@ -233,6 +235,12 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Edit bf16, fp16 by checking model's torch_dtype directly extra_args = "" + + # Add boolean for vision support + if "args" in call_args : + use_vision = "self.use_vision = args.use_vision" + extra_args += use_vision + if "args" in call_args and "model" in call_args: mixed_precision = \ "use_bf16 = getattr(args, 'bf16', False)\n"\ diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 171e75d19..4b402538b 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -168,11 +168,67 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # Autocast precision for GRPO -def grpo_trainer__prepare_inputs(function_name, function): - if function_name != "_prepare_inputs": return function +def grpo_generate_and_score_completions(function_name, function): + if function_name != "_generate_and_score_completions": return function if "with torch.inference_mode()" not in function: return function + if """prompts = [x["prompt"] for x in inputs]""" not in function : return function + + # Add vision handling + function = function.replace( + """prompts = [x["prompt"] for x in inputs]""", + + "prompts = [x['prompt'] for x in inputs]\n"\ + "if not self.use_vision:\n" \ + " pixel_values = None\n"\ + " image_grid_thw = None\n"\ + "else:\n"\ + " images = [x['image'] for x in inputs] # Only image inputs support for now \n"\ + ) + + # Output pixel values and image grid + function = function.replace( + """return { + "prompt_ids": prompt_ids, + "prompt_mask": prompt_mask, + "completion_ids": completion_ids, + "completion_mask": completion_mask, + "advantages": advantages, + "old_per_token_logps": old_per_token_logps, + "ref_per_token_logps": ref_per_token_logps, + }""", + + """return { + "prompt_ids": prompt_ids, + "prompt_mask": prompt_mask, + "pixel_values": pixel_values, + "image_grid_thw": image_grid_thw, + "completion_ids": completion_ids, + "completion_mask": completion_mask, + "advantages": advantages, + "old_per_token_logps": old_per_token_logps, + "ref_per_token_logps": ref_per_token_logps, + }""" + ) + + + function.replace("""self._get_per_token_logps( + self.model, prompt_completion_ids, attention_mask, logits_to_keep, batch_size + )""", + + """self._get_per_token_logps( + self.model, prompt_completion_ids, attention_mask,pixel_values, image_grid_thw, logits_to_keep, batch_size + )""") + function.replace("""self._get_per_token_logps( + self.model, 
prompt_completion_ids, attention_mask, logits_to_keep, batch_size + )""", + + """self._get_per_token_logps( + self.ref_model, prompt_completion_ids, attention_mask,pixel_values, image_grid_thw, logits_to_keep, batch_size + )""") + + # Add mixed precision training function = function.replace( "with torch.inference_mode():", @@ -191,7 +247,23 @@ def grpo_trainer__prepare_inputs(function_name, function): ) return function pass -RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__prepare_inputs) +RL_FUNCTIONS["grpo_trainer"].append(grpo_generate_and_score_completions) + +def grpo_prepare_inputs(function_name, function): + if function_name != "_prepare_inputs": return function + + if "accumulated_local_batch = self._generate_and_score_completions(accumulated_local_batch)" not in function : return function + + function = function.replace( + "accumulated_local_batch = self._generate_and_score_completions(accumulated_local_batch)", + + "accumulated_local_batch = self._generate_and_score_completions(accumulated_local_batch)\n"\ + "if self.use_vision : accumulated_local_batch['pixel_values']=accumulated_local_batch['pixel_values'].view(accumulated_local_batch['prompt_ids'].size(0), -1, accumulated_local_batch['pixel_values'].size(1)) # (batch_size * n_patches, dim embedding)->(batch_size,n_patches,dim embeddding)" + ) + + return function +pass +RL_FUNCTIONS["grpo_trainer"].append(grpo_prepare_inputs) # Remove _move_model_to_vllm @@ -210,7 +282,7 @@ def _move_model_to_vllm(self, *args, **kwargs): return None def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function - def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, calc_logprob_flag = None): + def _get_per_token_logps(self, model, input_ids, attention_mask,pixel_values,image_grid_thw, logits_to_keep, calc_logprob_flag = None): if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0' and not calc_logprob_flag: return None # Unsloth efficient GRPO # Otherwise, calculate normally: @@ -221,7 +293,10 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1" with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded - hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits + if self.use_vision : + hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values, image_grid_thw=image_grid_thw, logits_to_keep=logits_to_keep + 1).logits + else: + hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits #logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred return hidden_states # input_ids = input_ids[:, -logits_to_keep:] From 46fce646bc7baf1111dfcbaedfb6304129096c24 Mon Sep 17 00:00:00 2001 From: GAD-cell Date: Mon, 16 Jun 2025 17:54:17 +0200 Subject: [PATCH 02/16] fixed indentation --- unsloth/models/rl.py | 2 +- unsloth/models/rl_replacements.py | 25 ++++++++++++++----------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index f533711a5..27bb813a0 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -238,7 +238,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Add boolean for vision support if 
"args" in call_args : - use_vision = "self.use_vision = args.use_vision" + use_vision = "self.use_vision = args.use_vision \n" extra_args += use_vision if "args" in call_args and "model" in call_args: diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 4b402538b..486032a7f 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -173,19 +173,22 @@ def grpo_generate_and_score_completions(function_name, function): if "with torch.inference_mode()" not in function: return function - if """prompts = [x["prompt"] for x in inputs]""" not in function : return function + if """prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]""" not in function : return function # Add vision handling function = function.replace( - """prompts = [x["prompt"] for x in inputs]""", - - "prompts = [x['prompt'] for x in inputs]\n"\ - "if not self.use_vision:\n" \ - " pixel_values = None\n"\ - " image_grid_thw = None\n"\ - "else:\n"\ - " images = [x['image'] for x in inputs] # Only image inputs support for now \n"\ - ) + """prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]""", + + "prompts_text = [maybe_apply_chat_template(example, self.processing_class)['prompt'] for example in inputs]\n"\ + " if not self.use_vision:\n" \ + " pixel_values = None\n"\ + " image_grid_thw = None\n"\ + " prompt_inputs = self.processing_class(text=prompts_text, return_tensors='pt', padding=True, add_special_tokens=False)\n"\ + " else:\n"\ + " images = [x['image'] for x in inputs] # Only image inputs support for now \n"\ + " prompt_inputs = self.processing_class(images=images, text=prompts_text, return_tensors='pt', padding=True, add_special_tokens=False)\n"\ + " pixel_values, image_grid_thw = prompt_inputs['pixel_values'], prompt_inputs['image_grid_thw']\n" + ) # Output pixel values and image grid function = function.replace( @@ -258,7 +261,7 @@ def grpo_prepare_inputs(function_name, function): "accumulated_local_batch = self._generate_and_score_completions(accumulated_local_batch)", "accumulated_local_batch = self._generate_and_score_completions(accumulated_local_batch)\n"\ - "if self.use_vision : accumulated_local_batch['pixel_values']=accumulated_local_batch['pixel_values'].view(accumulated_local_batch['prompt_ids'].size(0), -1, accumulated_local_batch['pixel_values'].size(1)) # (batch_size * n_patches, dim embedding)->(batch_size,n_patches,dim embeddding)" + " if self.use_vision : accumulated_local_batch['pixel_values']=accumulated_local_batch['pixel_values'].view(accumulated_local_batch['prompt_ids'].size(0), -1, accumulated_local_batch['pixel_values'].size(1)) # (batch_size * n_patches, dim embedding)->(batch_size,n_patches,dim embeddding)" ) return function From f5d3006b29a7815c413a461719514bb303a6e3b1 Mon Sep 17 00:00:00 2001 From: GAD-cell Date: Mon, 16 Jun 2025 17:59:55 +0200 Subject: [PATCH 03/16] space error --- unsloth/models/rl_replacements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 486032a7f..f8f8acdd0 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -183,10 +183,10 @@ def grpo_generate_and_score_completions(function_name, function): " if not self.use_vision:\n" \ " pixel_values = None\n"\ " image_grid_thw = None\n"\ - " prompt_inputs = self.processing_class(text=prompts_text, return_tensors='pt', 
padding=True, add_special_tokens=False)\n"\ + " prompt_inputs = self.processing_class(text=prompts_text, return_tensors='pt', padding=True, add_special_tokens=False)\n"\ " else:\n"\ " images = [x['image'] for x in inputs] # Only image inputs support for now \n"\ - " prompt_inputs = self.processing_class(images=images, text=prompts_text, return_tensors='pt', padding=True, add_special_tokens=False)\n"\ + " prompt_inputs = self.processing_class(text=prompts_text, return_tensors='pt', padding=True, add_special_tokens=False)\n"\ " pixel_values, image_grid_thw = prompt_inputs['pixel_values'], prompt_inputs['image_grid_thw']\n" ) From 2d4a908373080661215867159e0a41809a98915e Mon Sep 17 00:00:00 2001 From: GAD-cell Date: Mon, 16 Jun 2025 18:02:18 +0200 Subject: [PATCH 04/16] indent fix --- unsloth/models/rl_replacements.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index f8f8acdd0..8586ba024 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -180,11 +180,11 @@ def grpo_generate_and_score_completions(function_name, function): """prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]""", "prompts_text = [maybe_apply_chat_template(example, self.processing_class)['prompt'] for example in inputs]\n"\ - " if not self.use_vision:\n" \ + " if not self.use_vision:\n" \ " pixel_values = None\n"\ " image_grid_thw = None\n"\ " prompt_inputs = self.processing_class(text=prompts_text, return_tensors='pt', padding=True, add_special_tokens=False)\n"\ - " else:\n"\ + " else:\n"\ " images = [x['image'] for x in inputs] # Only image inputs support for now \n"\ " prompt_inputs = self.processing_class(text=prompts_text, return_tensors='pt', padding=True, add_special_tokens=False)\n"\ " pixel_values, image_grid_thw = prompt_inputs['pixel_values'], prompt_inputs['image_grid_thw']\n" @@ -261,7 +261,7 @@ def grpo_prepare_inputs(function_name, function): "accumulated_local_batch = self._generate_and_score_completions(accumulated_local_batch)", "accumulated_local_batch = self._generate_and_score_completions(accumulated_local_batch)\n"\ - " if self.use_vision : accumulated_local_batch['pixel_values']=accumulated_local_batch['pixel_values'].view(accumulated_local_batch['prompt_ids'].size(0), -1, accumulated_local_batch['pixel_values'].size(1)) # (batch_size * n_patches, dim embedding)->(batch_size,n_patches,dim embeddding)" + " if self.use_vision : accumulated_local_batch['pixel_values']=accumulated_local_batch['pixel_values'].view(accumulated_local_batch['prompt_ids'].size(0), -1, accumulated_local_batch['pixel_values'].size(1)) # (batch_size * n_patches, dim embedding)->(batch_size,n_patches,dim embeddding)" ) return function From 15530f12710c32e52af02854af4049b61fb5972c Mon Sep 17 00:00:00 2001 From: GAD-cell Date: Mon, 16 Jun 2025 19:05:22 +0200 Subject: [PATCH 05/16] minor fixes --- unsloth/models/rl_replacements.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 8586ba024..206056d67 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -175,7 +175,15 @@ def grpo_generate_and_score_completions(function_name, function): if """prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]""" not in function : return function + # Add vision handling + 
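# Illustrative sketch, not part of the patch: what the patched prompt-processing code
# does at runtime. pixel_values / image_grid_thw are Qwen2-VL processor outputs, so a
# Qwen2-VL-style processor is assumed; the checkpoint and image path are hypothetical.
from transformers import AutoProcessor
from PIL import Image

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
text = ["<|vision_start|><|image_pad|><|vision_end|>Describe the image."]
batch = processor(images=[Image.open("example.jpg")], text=text,
                  return_tensors="pt", padding=True, add_special_tokens=False)
# Text-only path: processor(text=text, ...) -> just input_ids / attention_mask.
pixel_values   = batch["pixel_values"]    # (total_patches, patch_dim), flattened over the batch
image_grid_thw = batch["image_grid_thw"]  # (num_images, 3): temporal/height/width patch grid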
function = function.replace( + """prompt_inputs = self.processing_class( + text=prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False + )""", + "" + ) + function = function.replace( """prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]""", @@ -186,7 +194,7 @@ def grpo_generate_and_score_completions(function_name, function): " prompt_inputs = self.processing_class(text=prompts_text, return_tensors='pt', padding=True, add_special_tokens=False)\n"\ " else:\n"\ " images = [x['image'] for x in inputs] # Only image inputs support for now \n"\ - " prompt_inputs = self.processing_class(text=prompts_text, return_tensors='pt', padding=True, add_special_tokens=False)\n"\ + " prompt_inputs = self.processing_class(images = images,text=prompts_text, return_tensors='pt', padding=True, add_special_tokens=False)\n"\ " pixel_values, image_grid_thw = prompt_inputs['pixel_values'], prompt_inputs['image_grid_thw']\n" ) @@ -224,7 +232,7 @@ def grpo_generate_and_score_completions(function_name, function): self.model, prompt_completion_ids, attention_mask,pixel_values, image_grid_thw, logits_to_keep, batch_size )""") function.replace("""self._get_per_token_logps( - self.model, prompt_completion_ids, attention_mask, logits_to_keep, batch_size + self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep, batch_size )""", """self._get_per_token_logps( From 7ecc622ed8e6335efc49a4f5cd84d695d89391ae Mon Sep 17 00:00:00 2001 From: GAD-cell Date: Mon, 16 Jun 2025 19:27:43 +0200 Subject: [PATCH 06/16] working generate_and_score_completions --- unsloth/models/rl_replacements.py | 281 ++++++++++++++++++++++++------ 1 file changed, 228 insertions(+), 53 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 206056d67..8caee5919 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -171,46 +171,238 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch def grpo_generate_and_score_completions(function_name, function): if function_name != "_generate_and_score_completions": return function - if "with torch.inference_mode()" not in function: return function + def _generate_and_score_completions( + self, inputs) : + device = self.accelerator.device + mode = "eval" if self.control.should_evaluate else "train" + + prompts = [x["prompt"] for x in inputs] + prompts_text = [maybe_apply_chat_template(example, self.processing_class)['prompt'] for example in inputs] + if not self.use_vision: + pixel_values = None + image_grid_thw = None + prompt_inputs = self.processing_class(text=prompts_text, return_tensors='pt', padding=True, add_special_tokens=False) + prompt_inputs = super()._prepare_inputs(prompt_inputs) + else: + images = [x['image'] for x in inputs] # Only image inputs support for now + prompt_inputs = self.processing_class(images = images, text=prompts_text, return_tensors='pt', padding=True, add_special_tokens=False) + prompt_inputs = super()._prepare_inputs(prompt_inputs) + pixel_values, image_grid_thw = prompt_inputs['pixel_values'], prompt_inputs['image_grid_thw'] + + prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] + + if self.max_prompt_length is not None: + prompt_ids = prompt_ids[:, -self.max_prompt_length :] + prompt_mask = prompt_mask[:, -self.max_prompt_length :] + + # Generate completions using either vLLM or regular generation + if self.use_vllm: + # First, have 
main process load weights if needed + if self.state.global_step != self._last_loaded_step: + self._move_model_to_vllm() + self._last_loaded_step = self.state.global_step + + # Generate completions using vLLM: gather all prompts and use them in a single call in the main process + all_prompts_text = gather_object(prompts_text) + if self.accelerator.is_main_process: + # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate + # num_generations outputs for each one. This is faster than generating outputs for each duplicate + # prompt individually. + ordered_set_of_prompts = all_prompts_text[:: self.num_generations] + with profiling_context(self, "vLLM.generate"): + completion_ids = self.vllm_client.generate( + prompts=ordered_set_of_prompts, + n=self.num_generations, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=-1 if self.top_k is None else self.top_k, + min_p=0.0 if self.min_p is None else self.min_p, + max_tokens=self.max_completion_length, + guided_decoding_regex=self.guided_decoding_regex, + ) + else: + completion_ids = [None] * len(all_prompts_text) + # Broadcast the completions from the main process to all processes, ensuring each process receives its + # corresponding slice. + completion_ids = broadcast_object_list(completion_ids, from_process=0) + process_slice = slice( + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), + ) + completion_ids = completion_ids[process_slice] - if """prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]""" not in function : return function + # Pad the completions, and concatenate them with the prompts + completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id) + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) + else: + # Regular generation path + with unwrap_model_for_generation( + self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model: + if self.use_vision : prompt_completion_ids = unwrapped_model.generate(prompt_ids, attention_mask=prompt_mask,pixel_values = pixel_values,image_grid_thw=image_grid_thw, generation_config=self.generation_config) + else : prompt_completion_ids = unwrapped_model.generate(prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config) + + # Compute prompt length and extract completion ids + prompt_length = prompt_ids.size(1) + prompt_ids = prompt_completion_ids[:, :prompt_length] + completion_ids = prompt_completion_ids[:, prompt_length:] + + # Mask everything after the first EOS token + is_eos = completion_ids == self.processing_class.eos_token_id + eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) + eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] + sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) + completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + + # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask + if self.mask_truncated_completions: + truncated_completions = ~is_eos.any(dim=1) + completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int() + + # Concatenate prompt_mask with completion_mask for logit computation + 
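# Toy check of the EOS masking above (illustrative only, eos_token_id assumed = 2):
#   completion_ids = torch.tensor([[5, 2, 7], [5, 6, 7]])
#   is_eos          -> [[F, T, F], [F, F, F]];  eos_idx -> [1, 3]
#   completion_mask -> [[1, 1, 0], [1, 1, 1]]   (tokens after the first EOS are masked out)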
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size + + with torch.no_grad(): + # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip it's + # computation here, and use per_token_logps.detach() instead. + if self.num_iterations > 1: + old_per_token_logps = self._get_per_token_logps( + self.model, prompt_completion_ids, attention_mask,pixel_values,image_grid_thw, logits_to_keep, batch_size + ) + else: + old_per_token_logps = None - # Add vision handling - function = function.replace( - """prompt_inputs = self.processing_class( - text=prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False - )""", - "" - ) + if self.beta == 0.0: + ref_per_token_logps = None + elif self.ref_model is not None: + ref_per_token_logps = self._get_per_token_logps( + self.ref_model, prompt_completion_ids, attention_mask,pixel_values,image_grid_thw, logits_to_keep, batch_size + ) + else: + with self.accelerator.unwrap_model(self.model, keep_fp32_wrapper = False).disable_adapter(): + ref_per_token_logps = self._get_per_token_logps( + self.model, prompt_completion_ids, attention_mask,pixel_values,image_grid_thw, logits_to_keep, batch_size + ) + + # Decode the generated completions + completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + if is_conversational(inputs[0]): + completions = [] + for prompt, completion in zip(prompts, completions_text): + bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else "" + completions.append([{"role": "assistant", "content": bootstrap + completion}]) + else: + completions = completions_text + + rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) + for i, (reward_func, reward_processing_class, reward_func_name) in enumerate( + zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names) + ): + with profiling_context(self, reward_func_name): + if isinstance( + reward_func, nn.Module + ): # Module instead of PretrainedModel for compat with compiled models + if is_conversational(inputs[0]): + messages = [{"messages": p + c} for p, c in zip(prompts, completions)] + texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] + else: + texts = [p + c for p, c in zip(prompts, completions)] + reward_inputs = reward_processing_class( + text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False + ) + reward_inputs = super()._prepare_inputs(reward_inputs) + with torch.inference_mode(), torch.amp.autocast(device_type = 'cuda', dtype = ((torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) if not torch.is_autocast_enabled('cuda') else nullcontext())if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '0' else torch.float16): + rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,) + else: + # Repeat all input columns (but "prompt" and "completion") to match the number of generations + keys = [key for key in inputs[0] if key not in ["prompt", "completion"]] + reward_kwargs = {key: [example[key] for example in inputs] for key in keys} + output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs) + # Convert None 
values to NaN + output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] + + rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) + + # If all reward functions return None for a given row, issue a detailed warning + if torch.isnan(rewards_per_func).all(dim=1).any(): + nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0] + row_reward_kwargs = {key: value[nan_row_idx] for key, value in reward_kwargs.items()} + row_reward_kwargs["prompt"] = prompts[nan_row_idx] + row_reward_kwargs["completion"] = completions[nan_row_idx] + warnings.warn( + f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. " + "Please ensure that at least one reward function returns a valid reward." + ) - function = function.replace( - """prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]""", - - "prompts_text = [maybe_apply_chat_template(example, self.processing_class)['prompt'] for example in inputs]\n"\ - " if not self.use_vision:\n" \ - " pixel_values = None\n"\ - " image_grid_thw = None\n"\ - " prompt_inputs = self.processing_class(text=prompts_text, return_tensors='pt', padding=True, add_special_tokens=False)\n"\ - " else:\n"\ - " images = [x['image'] for x in inputs] # Only image inputs support for now \n"\ - " prompt_inputs = self.processing_class(images = images,text=prompts_text, return_tensors='pt', padding=True, add_special_tokens=False)\n"\ - " pixel_values, image_grid_thw = prompt_inputs['pixel_values'], prompt_inputs['image_grid_thw']\n" - ) + # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the + # completions may be distributed across processes + rewards_per_func = gather(rewards_per_func) - # Output pixel values and image grid - function = function.replace( - """return { - "prompt_ids": prompt_ids, - "prompt_mask": prompt_mask, - "completion_ids": completion_ids, - "completion_mask": completion_mask, - "advantages": advantages, - "old_per_token_logps": old_per_token_logps, - "ref_per_token_logps": ref_per_token_logps, - }""", + # Apply weights to each reward function's output and sum + rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + + # Compute grouped-wise rewards + mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1) + std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1) - """return { + # Normalize the rewards to compute the advantages + mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0) + std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0) + advantages = rewards - mean_grouped_rewards + if self.scale_rewards: + advantages = advantages / (std_grouped_rewards + 1e-4) + + # Slice to keep only the local part of the data + process_slice = slice( + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), + ) + advantages = advantages[process_slice] + + # Log the metrics + if mode == "train": + self.state.num_input_tokens_seen += self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item() + self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] + + # log completion lengths, mean, min, max + agg_completion_mask = self.accelerator.gather_for_metrics(completion_mask.sum(1)) + 
self._metrics[mode]["completions/mean_length"].append(agg_completion_mask.float().mean().item()) + self._metrics[mode]["completions/min_length"].append(agg_completion_mask.float().min().item()) + self._metrics[mode]["completions/max_length"].append(agg_completion_mask.float().max().item()) + + # identify sequences that terminated with EOS and log their lengths + agg_terminated_with_eos = self.accelerator.gather_for_metrics(is_eos.any(dim=1)) + term_completion_mask = agg_completion_mask[agg_terminated_with_eos] + clipped_completions_ratio = 1 - len(term_completion_mask) / len(agg_completion_mask) + self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio) + if len(term_completion_mask) == 0: + # edge case where no completed sequences are found + term_completion_mask = torch.zeros(1, device=device) + self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_mask.float().mean().item()) + self._metrics[mode]["completions/min_terminated_length"].append(term_completion_mask.float().min().item()) + self._metrics[mode]["completions/max_terminated_length"].append(term_completion_mask.float().max().item()) + + # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) + for i, reward_func_name in enumerate(self.reward_func_names): + mean_rewards = torch.nanmean(rewards_per_func[:, i]).item() + self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards) + std_rewards = nanstd(rewards_per_func[:, i]).item() + self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_rewards) + self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item()) + self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item()) + + # Log prompt and completion texts + self._textual_logs["prompt"].extend(gather_object(prompts_text)) + self._textual_logs["completion"].extend(gather_object(completions_text)) + for i, name in enumerate(self.reward_func_names): + self._textual_logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) + + return { "prompt_ids": prompt_ids, "prompt_mask": prompt_mask, "pixel_values": pixel_values, @@ -220,26 +412,9 @@ def grpo_generate_and_score_completions(function_name, function): "advantages": advantages, "old_per_token_logps": old_per_token_logps, "ref_per_token_logps": ref_per_token_logps, - }""" - ) - + } - function.replace("""self._get_per_token_logps( - self.model, prompt_completion_ids, attention_mask, logits_to_keep, batch_size - )""", - - """self._get_per_token_logps( - self.model, prompt_completion_ids, attention_mask,pixel_values, image_grid_thw, logits_to_keep, batch_size - )""") - function.replace("""self._get_per_token_logps( - self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep, batch_size - )""", - - """self._get_per_token_logps( - self.ref_model, prompt_completion_ids, attention_mask,pixel_values, image_grid_thw, logits_to_keep, batch_size - )""") - - + function = inspect.getsource(_generate_and_score_completions) # Add mixed precision training function = function.replace( "with torch.inference_mode():", From d9401f7d975b1bc53c0bb3b0d8318cc4f2d84cc2 Mon Sep 17 00:00:00 2001 From: GAD-cell Date: Tue, 17 Jun 2025 00:31:34 +0200 Subject: [PATCH 07/16] working version with hidden states trimming --- unsloth/models/rl.py | 2 +- unsloth/models/rl_replacements.py | 34 +++++++++++-------------------- 2 files changed, 13 insertions(+), 23 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 
27bb813a0..f4d1e9e2a 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -238,7 +238,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Add boolean for vision support if "args" in call_args : - use_vision = "self.use_vision = args.use_vision \n" + use_vision = "self.use_vision = args.use_vision\n" extra_args += use_vision if "args" in call_args and "model" in call_args: diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 8caee5919..da563a1c2 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -181,11 +181,11 @@ def _generate_and_score_completions( if not self.use_vision: pixel_values = None image_grid_thw = None - prompt_inputs = self.processing_class(text=prompts_text, return_tensors='pt', padding=True, add_special_tokens=False) + prompt_inputs = self.processing_class(text=prompts_text, return_tensors='pt', padding=True,padding_side="left", add_special_tokens=False) prompt_inputs = super()._prepare_inputs(prompt_inputs) else: images = [x['image'] for x in inputs] # Only image inputs support for now - prompt_inputs = self.processing_class(images = images, text=prompts_text, return_tensors='pt', padding=True, add_special_tokens=False) + prompt_inputs = self.processing_class(images = images, text=prompts_text, return_tensors='pt', padding=True,padding_side="left", add_special_tokens=False) prompt_inputs = super()._prepare_inputs(prompt_inputs) pixel_values, image_grid_thw = prompt_inputs['pixel_values'], prompt_inputs['image_grid_thw'] @@ -277,17 +277,6 @@ def _generate_and_score_completions( else: old_per_token_logps = None - if self.beta == 0.0: - ref_per_token_logps = None - elif self.ref_model is not None: - ref_per_token_logps = self._get_per_token_logps( - self.ref_model, prompt_completion_ids, attention_mask,pixel_values,image_grid_thw, logits_to_keep, batch_size - ) - else: - with self.accelerator.unwrap_model(self.model, keep_fp32_wrapper = False).disable_adapter(): - ref_per_token_logps = self._get_per_token_logps( - self.model, prompt_completion_ids, attention_mask,pixel_values,image_grid_thw, logits_to_keep, batch_size - ) # Decode the generated completions completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) @@ -411,7 +400,6 @@ def _generate_and_score_completions( "completion_mask": completion_mask, "advantages": advantages, "old_per_token_logps": old_per_token_logps, - "ref_per_token_logps": ref_per_token_logps, } function = inspect.getsource(_generate_and_score_completions) @@ -438,13 +426,13 @@ def _generate_and_score_completions( def grpo_prepare_inputs(function_name, function): if function_name != "_prepare_inputs": return function - if "accumulated_local_batch = self._generate_and_score_completions(accumulated_local_batch)" not in function : return function + if "generation_batch = self._generate_and_score_completions(generation_batch)" not in function : return function function = function.replace( - "accumulated_local_batch = self._generate_and_score_completions(accumulated_local_batch)", + "generation_batch = self._generate_and_score_completions(generation_batch)", - "accumulated_local_batch = self._generate_and_score_completions(accumulated_local_batch)\n"\ - " if self.use_vision : accumulated_local_batch['pixel_values']=accumulated_local_batch['pixel_values'].view(accumulated_local_batch['prompt_ids'].size(0), -1, accumulated_local_batch['pixel_values'].size(1)) # (batch_size * n_patches, dim 
embedding)->(batch_size,n_patches,dim embeddding)" + "generation_batch = self._generate_and_score_completions(generation_batch)\n"\ + " if self.use_vision : generation_batch['pixel_values']=generation_batch['pixel_values'].view(generation_batch['prompt_ids'].size(0), -1, generation_batch['pixel_values'].size(1)) # (batch_size * n_patches, dim embedding)->(batch_size,n_patches,dim embeddding)" ) return function @@ -484,6 +472,8 @@ def _get_per_token_logps(self, model, input_ids, attention_mask,pixel_values,ima else: hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits #logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred + hidden_states = hidden_states[:, :-1, :] + hidden_states = hidden_states[:, -logits_to_keep:] return hidden_states # input_ids = input_ids[:, -logits_to_keep:] # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. @@ -526,6 +516,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # Compute the per-token log probabilities for the model prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] + pixel_values,image_grid_thw = inputs.get("pixel_values", None), inputs.get("image_grid_thw", None) completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] input_ids = torch.cat([prompt_ids, completion_ids], dim=1) bsz, qlen = input_ids.shape @@ -534,19 +525,18 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens _input_ids = input_ids _logits_to_keep = logits_to_keep - - per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) + print(input_ids.shape) + per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask,pixel_values,image_grid_thw, logits_to_keep) # Compute the KL divergence between the model and the reference model # _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves. 
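# The commented-out line below is TRL's per-token KL approximation (the k3 estimator
# from Schulman's "Approximating KL Divergence" note): always non-negative and
# unbiased in expectation. Toy check, illustrative only:
#   logp, ref = -1.0, -1.2
#   kl = exp(ref - logp) - (ref - logp) - 1 = exp(-0.2) + 0.2 - 1 ~= 0.0187 >= 0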
# https://github.com/huggingface/trl/blob/05bc43e960396581e458195b8388efe6b82cae1f/trl/trainer/grpo_trainer.py#L1328 if self.beta != 0.0: with torch.inference_mode(), model.disable_adapter(): - ref_per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) + ref_per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask,pixel_values,image_grid_thw, logits_to_keep) else: ref_per_token_logps = None # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 - # x - x.detach() allows for preserving gradients from x advantages = inputs["advantages"] # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) From 0a8c3e2f64c9971d2ea0472203e357515fbec1eb Mon Sep 17 00:00:00 2001 From: GAD-cell Date: Tue, 17 Jun 2025 01:10:37 +0200 Subject: [PATCH 08/16] remove print --- unsloth/models/rl_replacements.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index da563a1c2..16b9e110f 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -525,7 +525,6 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens _input_ids = input_ids _logits_to_keep = logits_to_keep - print(input_ids.shape) per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask,pixel_values,image_grid_thw, logits_to_keep) # Compute the KL divergence between the model and the reference model From a162fca5b6b09e7810c10b17da4e2112ee586c79 Mon Sep 17 00:00:00 2001 From: GAD-cell Date: Tue, 17 Jun 2025 23:44:12 +0200 Subject: [PATCH 09/16] Replace _generate_and_score_completions using function replacements instead of copy --- unsloth/models/rl_replacements.py | 316 ++++++++---------------------- 1 file changed, 87 insertions(+), 229 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 16b9e110f..2c6c5d413 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -169,241 +169,96 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # Autocast precision for GRPO def grpo_generate_and_score_completions(function_name, function): - if function_name != "_generate_and_score_completions": return function - - def _generate_and_score_completions( - self, inputs) : - device = self.accelerator.device - mode = "eval" if self.control.should_evaluate else "train" - - prompts = [x["prompt"] for x in inputs] - prompts_text = [maybe_apply_chat_template(example, self.processing_class)['prompt'] for example in inputs] - if not self.use_vision: - pixel_values = None - image_grid_thw = None - prompt_inputs = self.processing_class(text=prompts_text, return_tensors='pt', padding=True,padding_side="left", add_special_tokens=False) - prompt_inputs = super()._prepare_inputs(prompt_inputs) - else: - images = [x['image'] for x in inputs] # Only image inputs support for now - prompt_inputs = self.processing_class(images = images, text=prompts_text, return_tensors='pt', padding=True,padding_side="left", add_special_tokens=False) - prompt_inputs = super()._prepare_inputs(prompt_inputs) - pixel_values, image_grid_thw = prompt_inputs['pixel_values'], prompt_inputs['image_grid_thw'] - - prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] - 
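# Sketch of the indent-preserving regex patching this commit switches to (illustrative
# only; the names here are made up). A named group captures the leading indentation and
# (?P=indent) backreferences it, so one pattern matches at any nesting depth:
import re

src = "    out = f(\n        a, b\n    )"
pattern = re.compile(
    r"^(?P<indent>\s*)out = f\(\n"
    r"(?P=indent) {4}a, b\n"
    r"(?P=indent)\)",
    re.MULTILINE,
)
patched = pattern.sub(r"\g<indent>out = f(\n\g<indent>    a, b, extra\n\g<indent>)", src)
# patched == "    out = f(\n        a, b, extra\n    )"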
- if self.max_prompt_length is not None: - prompt_ids = prompt_ids[:, -self.max_prompt_length :] - prompt_mask = prompt_mask[:, -self.max_prompt_length :] - - # Generate completions using either vLLM or regular generation - if self.use_vllm: - # First, have main process load weights if needed - if self.state.global_step != self._last_loaded_step: - self._move_model_to_vllm() - self._last_loaded_step = self.state.global_step - - # Generate completions using vLLM: gather all prompts and use them in a single call in the main process - all_prompts_text = gather_object(prompts_text) - if self.accelerator.is_main_process: - # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate - # num_generations outputs for each one. This is faster than generating outputs for each duplicate - # prompt individually. - ordered_set_of_prompts = all_prompts_text[:: self.num_generations] - with profiling_context(self, "vLLM.generate"): - completion_ids = self.vllm_client.generate( - prompts=ordered_set_of_prompts, - n=self.num_generations, - repetition_penalty=self.repetition_penalty, - temperature=self.temperature, - top_p=self.top_p, - top_k=-1 if self.top_k is None else self.top_k, - min_p=0.0 if self.min_p is None else self.min_p, - max_tokens=self.max_completion_length, - guided_decoding_regex=self.guided_decoding_regex, - ) - else: - completion_ids = [None] * len(all_prompts_text) - # Broadcast the completions from the main process to all processes, ensuring each process receives its - # corresponding slice. - completion_ids = broadcast_object_list(completion_ids, from_process=0) - process_slice = slice( - self.accelerator.process_index * len(prompts), - (self.accelerator.process_index + 1) * len(prompts), - ) - completion_ids = completion_ids[process_slice] + if function_name != "_generate_and_score_completions": + return function - # Pad the completions, and concatenate them with the prompts - completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] - completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id) - prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) - else: - # Regular generation path - with unwrap_model_for_generation( - self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation - ) as unwrapped_model: - if self.use_vision : prompt_completion_ids = unwrapped_model.generate(prompt_ids, attention_mask=prompt_mask,pixel_values = pixel_values,image_grid_thw=image_grid_thw, generation_config=self.generation_config) - else : prompt_completion_ids = unwrapped_model.generate(prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config) - - # Compute prompt length and extract completion ids - prompt_length = prompt_ids.size(1) - prompt_ids = prompt_completion_ids[:, :prompt_length] - completion_ids = prompt_completion_ids[:, prompt_length:] - - # Mask everything after the first EOS token - is_eos = completion_ids == self.processing_class.eos_token_id - eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) - eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] - sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) - completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() - - # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask - if self.mask_truncated_completions: - truncated_completions = 
~is_eos.any(dim=1) - completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int() - - # Concatenate prompt_mask with completion_mask for logit computation - attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) - - logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens - batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size - - with torch.no_grad(): - # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip it's - # computation here, and use per_token_logps.detach() instead. - if self.num_iterations > 1: - old_per_token_logps = self._get_per_token_logps( - self.model, prompt_completion_ids, attention_mask,pixel_values,image_grid_thw, logits_to_keep, batch_size - ) - else: - old_per_token_logps = None + if "with torch.inference_mode()" not in function: + return function + if """prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]""" not in function: + return function + - # Decode the generated completions - completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) - if is_conversational(inputs[0]): - completions = [] - for prompt, completion in zip(prompts, completions_text): - bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else "" - completions.append([{"role": "assistant", "content": bootstrap + completion}]) - else: - completions = completions_text - - rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) - for i, (reward_func, reward_processing_class, reward_func_name) in enumerate( - zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names) - ): - with profiling_context(self, reward_func_name): - if isinstance( - reward_func, nn.Module - ): # Module instead of PretrainedModel for compat with compiled models - if is_conversational(inputs[0]): - messages = [{"messages": p + c} for p, c in zip(prompts, completions)] - texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] - else: - texts = [p + c for p, c in zip(prompts, completions)] - reward_inputs = reward_processing_class( - text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False - ) - reward_inputs = super()._prepare_inputs(reward_inputs) - with torch.inference_mode(), torch.amp.autocast(device_type = 'cuda', dtype = ((torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) if not torch.is_autocast_enabled('cuda') else nullcontext())if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '0' else torch.float16): - rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,) - else: - # Repeat all input columns (but "prompt" and "completion") to match the number of generations - keys = [key for key in inputs[0] if key not in ["prompt", "completion"]] - reward_kwargs = {key: [example[key] for example in inputs] for key in keys} - output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs) - # Convert None values to NaN - output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] - - rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) - - # If all reward functions return None for a given row, issue a detailed warning - if 
torch.isnan(rewards_per_func).all(dim=1).any(): - nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0] - row_reward_kwargs = {key: value[nan_row_idx] for key, value in reward_kwargs.items()} - row_reward_kwargs["prompt"] = prompts[nan_row_idx] - row_reward_kwargs["completion"] = completions[nan_row_idx] - warnings.warn( - f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. " - "Please ensure that at least one reward function returns a valid reward." - ) + # 1. Output pixel_values and image_grid_thw + pattern = re.compile( + r"^(?P\s*)return {\n" + r"(?P=indent) {4}\"prompt_ids\": prompt_ids,\n" + r"(?P=indent) {4}\"prompt_mask\": prompt_mask,\n" + r"(?P=indent) {4}\"completion_ids\": completion_ids,\n" + r"(?P=indent) {4}\"completion_mask\": completion_mask,\n" + r"(?P=indent) {4}\"advantages\": advantages,\n" + r"(?P=indent) {4}\"old_per_token_logps\": old_per_token_logps,\n" + r"(?P=indent)}", + re.MULTILINE + ) - # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the - # completions may be distributed across processes - rewards_per_func = gather(rewards_per_func) + replacement = """ return { + "prompt_ids": prompt_ids, + "prompt_mask": prompt_mask, + "pixel_values": pixel_values, + "image_grid_thw": image_grid_thw, + "completion_ids": completion_ids, + "completion_mask": completion_mask, + "advantages": advantages, + "old_per_token_logps": old_per_token_logps, + }""" + function = re.sub(pattern, replacement, function) + + # 2. Replace the prompt_completion_ids generation + pattern = re.compile( + r"^(?P\s*)prompt_completion_ids = unwrapped_model\.generate\(\n" + r"(?P=indent) {4}prompt_ids, attention_mask=prompt_mask, generation_config=self\.generation_config\n" + r"(?P=indent)\)", + re.MULTILINE + ) - # Apply weights to each reward function's output and sum - rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + replacement = """ if self.use_vision : prompt_completion_ids = unwrapped_model.generate(prompt_ids, attention_mask=prompt_mask,pixel_values = pixel_values,image_grid_thw=image_grid_thw, generation_config=self.generation_config) + else : prompt_completion_ids = unwrapped_model.generate(prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config)""" + + function = pattern.sub(replacement, function) + + # 3. Replace the old_per_token_logps generation + pattern = re.compile( + r"^(?P\s*)old_per_token_logps = self\._get_per_token_logps\(\n" + r"(?P=indent) {4}self\.model, prompt_completion_ids, attention_mask, logits_to_keep, batch_size\n" + r"(?P=indent)\)", + re.MULTILINE +) + + replacement = """ old_per_token_logps = self._get_per_token_logps( + self.model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep, batch_size + )""" + + function = re.sub(pattern, replacement, function) + + # 4. 
Replace the prompt processing section + pattern = re.compile( + r"^(?P\s*)prompts = \[x\[\"prompt\"\] for x in inputs\]\n" + r"(?P=indent)prompts_text = \[maybe_apply_chat_template\(example, self\.processing_class\)\[\"prompt\"\] for example in inputs\]\n" + r"(?P=indent)prompt_inputs = self\.processing_class\(\n" + r"(?P=indent) {4}text=prompts_text, return_tensors=\"pt\", padding=True, padding_side=\"left\", add_special_tokens=False\n" + r"(?P=indent)\)\n" + r"(?P=indent)prompt_inputs = super\(\)\._prepare_inputs\(prompt_inputs\)" + , + re.MULTILINE + ) - # Compute grouped-wise rewards - mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1) - std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1) + replacement = """ prompts = [x["prompt"] for x in inputs] + prompts_text = [maybe_apply_chat_template(example, self.processing_class)['prompt'] for example in inputs] + if not self.use_vision: + pixel_values, image_grid_thw = None, None + prompt_inputs = self.processing_class(text=prompts_text, return_tensors='pt', padding=True, padding_side="left", add_special_tokens=False) + prompt_inputs = super()._prepare_inputs(prompt_inputs) + else: + images = [x['image'] for x in inputs] # Only image inputs support for now + prompt_inputs = self.processing_class(images=images, text=prompts_text, return_tensors='pt', padding=True, padding_side="left", add_special_tokens=False) + prompt_inputs = super()._prepare_inputs(prompt_inputs) + pixel_values, image_grid_thw = prompt_inputs['pixel_values'], prompt_inputs['image_grid_thw']""" - # Normalize the rewards to compute the advantages - mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0) - std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0) - advantages = rewards - mean_grouped_rewards - if self.scale_rewards: - advantages = advantages / (std_grouped_rewards + 1e-4) + function = pattern.sub(replacement, function) - # Slice to keep only the local part of the data - process_slice = slice( - self.accelerator.process_index * len(prompts), - (self.accelerator.process_index + 1) * len(prompts), - ) - advantages = advantages[process_slice] - # Log the metrics - if mode == "train": - self.state.num_input_tokens_seen += self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item() - self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] - - # log completion lengths, mean, min, max - agg_completion_mask = self.accelerator.gather_for_metrics(completion_mask.sum(1)) - self._metrics[mode]["completions/mean_length"].append(agg_completion_mask.float().mean().item()) - self._metrics[mode]["completions/min_length"].append(agg_completion_mask.float().min().item()) - self._metrics[mode]["completions/max_length"].append(agg_completion_mask.float().max().item()) - - # identify sequences that terminated with EOS and log their lengths - agg_terminated_with_eos = self.accelerator.gather_for_metrics(is_eos.any(dim=1)) - term_completion_mask = agg_completion_mask[agg_terminated_with_eos] - clipped_completions_ratio = 1 - len(term_completion_mask) / len(agg_completion_mask) - self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio) - if len(term_completion_mask) == 0: - # edge case where no completed sequences are found - term_completion_mask = torch.zeros(1, device=device) - self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_mask.float().mean().item()) - 
self._metrics[mode]["completions/min_terminated_length"].append(term_completion_mask.float().min().item()) - self._metrics[mode]["completions/max_terminated_length"].append(term_completion_mask.float().max().item()) - - # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) - for i, reward_func_name in enumerate(self.reward_func_names): - mean_rewards = torch.nanmean(rewards_per_func[:, i]).item() - self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards) - std_rewards = nanstd(rewards_per_func[:, i]).item() - self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_rewards) - self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item()) - self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item()) - - # Log prompt and completion texts - self._textual_logs["prompt"].extend(gather_object(prompts_text)) - self._textual_logs["completion"].extend(gather_object(completions_text)) - for i, name in enumerate(self.reward_func_names): - self._textual_logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) - - return { - "prompt_ids": prompt_ids, - "prompt_mask": prompt_mask, - "pixel_values": pixel_values, - "image_grid_thw": image_grid_thw, - "completion_ids": completion_ids, - "completion_mask": completion_mask, - "advantages": advantages, - "old_per_token_logps": old_per_token_logps, - } - - function = inspect.getsource(_generate_and_score_completions) - # Add mixed precision training + # Add autocast for mixed precision function = function.replace( "with torch.inference_mode():", @@ -419,6 +274,7 @@ def _generate_and_score_completions( "self.accelerator.unwrap_model(self.model)", "self.accelerator.unwrap_model(self.model, keep_fp32_wrapper = False)", ) + return function pass RL_FUNCTIONS["grpo_trainer"].append(grpo_generate_and_score_completions) @@ -472,8 +328,8 @@ def _get_per_token_logps(self, model, input_ids, attention_mask,pixel_values,ima else: hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits #logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred - hidden_states = hidden_states[:, :-1, :] - hidden_states = hidden_states[:, -logits_to_keep:] + hidden_states = hidden_states[:, :-1, :] + hidden_states = hidden_states[:, -logits_to_keep:, :] return hidden_states # input_ids = input_ids[:, -logits_to_keep:] # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. 
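# Why logits_to_keep + 1 followed by the two slices above (illustrative shapes only):
# the logit at position t predicts token t + 1, so the K completion tokens are scored by
# the K positions ending one step before the sequence end; the very last logit predicts
# an unseen token and is dropped. The final [-K:] slice also covers transformers
# versions that ignore logits_to_keep and return full-length logits.
import torch
B, K, V = 2, 4, 10                # toy batch size, completion length, vocab size
hs = torch.randn(B, K + 1, V)     # stands in for model(..., logits_to_keep=K + 1).logits
hs = hs[:, :-1, :]                # drop the final next-token prediction -> (B, K, V)
hs = hs[:, -K:, :]                # no-op here; trims full-length logits otherwise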
@@ -516,8 +372,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # Compute the per-token log probabilities for the model prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] - pixel_values,image_grid_thw = inputs.get("pixel_values", None), inputs.get("image_grid_thw", None) completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] + pixel_values,image_grid_thw = inputs.get("pixel_values", None), inputs.get("image_grid_thw", None) input_ids = torch.cat([prompt_ids, completion_ids], dim=1) bsz, qlen = input_ids.shape attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) @@ -525,6 +381,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens _input_ids = input_ids _logits_to_keep = logits_to_keep + per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask,pixel_values,image_grid_thw, logits_to_keep) # Compute the KL divergence between the model and the reference model @@ -536,6 +393,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch else: ref_per_token_logps = None # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 + # x - x.detach() allows for preserving gradients from x advantages = inputs["advantages"] # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) From 32571e94ef20f8ca5c5f6a6aa3441c8835c39509 Mon Sep 17 00:00:00 2001 From: GAD-cell Date: Tue, 17 Jun 2025 23:46:18 +0200 Subject: [PATCH 10/16] typo correction --- unsloth/models/rl_replacements.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 2c6c5d413..087dc1d0e 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -223,7 +223,7 @@ def grpo_generate_and_score_completions(function_name, function): r"(?P=indent) {4}self\.model, prompt_completion_ids, attention_mask, logits_to_keep, batch_size\n" r"(?P=indent)\)", re.MULTILINE -) + ) replacement = """ old_per_token_logps = self._get_per_token_logps( self.model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep, batch_size @@ -277,6 +277,7 @@ def grpo_generate_and_score_completions(function_name, function): return function pass + RL_FUNCTIONS["grpo_trainer"].append(grpo_generate_and_score_completions) def grpo_prepare_inputs(function_name, function): From 3edd2545af90f1ef3835fd1680891662b32a60f9 Mon Sep 17 00:00:00 2001 From: GAD-cell Date: Wed, 18 Jun 2025 14:07:43 +0200 Subject: [PATCH 11/16] indentation fix --- unsloth/models/rl_replacements.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 087dc1d0e..1424f6211 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -192,7 +192,7 @@ def grpo_generate_and_score_completions(function_name, function): re.MULTILINE ) - replacement = """ return { + replacement = """ return { "prompt_ids": prompt_ids, "prompt_mask": prompt_mask, "pixel_values": pixel_values, @@ -201,7 +201,7 @@ def grpo_generate_and_score_completions(function_name, function): "completion_mask": completion_mask, "advantages": advantages, "old_per_token_logps": 
old_per_token_logps, - }""" + }""" function = re.sub(pattern, replacement, function) # 2. Replace the prompt_completion_ids generation @@ -212,8 +212,8 @@ def grpo_generate_and_score_completions(function_name, function): re.MULTILINE ) - replacement = """ if self.use_vision : prompt_completion_ids = unwrapped_model.generate(prompt_ids, attention_mask=prompt_mask,pixel_values = pixel_values,image_grid_thw=image_grid_thw, generation_config=self.generation_config) - else : prompt_completion_ids = unwrapped_model.generate(prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config)""" + replacement = """ if self.use_vision : prompt_completion_ids = unwrapped_model.generate(prompt_ids, attention_mask=prompt_mask,pixel_values = pixel_values,image_grid_thw=image_grid_thw, generation_config=self.generation_config) + else : prompt_completion_ids = unwrapped_model.generate(prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config)""" function = pattern.sub(replacement, function) @@ -243,17 +243,17 @@ def grpo_generate_and_score_completions(function_name, function): re.MULTILINE ) - replacement = """ prompts = [x["prompt"] for x in inputs] - prompts_text = [maybe_apply_chat_template(example, self.processing_class)['prompt'] for example in inputs] - if not self.use_vision: - pixel_values, image_grid_thw = None, None - prompt_inputs = self.processing_class(text=prompts_text, return_tensors='pt', padding=True, padding_side="left", add_special_tokens=False) - prompt_inputs = super()._prepare_inputs(prompt_inputs) - else: - images = [x['image'] for x in inputs] # Only image inputs support for now - prompt_inputs = self.processing_class(images=images, text=prompts_text, return_tensors='pt', padding=True, padding_side="left", add_special_tokens=False) - prompt_inputs = super()._prepare_inputs(prompt_inputs) - pixel_values, image_grid_thw = prompt_inputs['pixel_values'], prompt_inputs['image_grid_thw']""" + replacement = """ prompts = [x["prompt"] for x in inputs] + prompts_text = [maybe_apply_chat_template(example, self.processing_class)['prompt'] for example in inputs] + if not self.use_vision: + pixel_values, image_grid_thw = None, None + prompt_inputs = self.processing_class(text=prompts_text, return_tensors='pt', padding=True, padding_side="left", add_special_tokens=False) + prompt_inputs = super()._prepare_inputs(prompt_inputs) + else: + images = [x['image'] for x in inputs] # Only image inputs support for now + prompt_inputs = self.processing_class(images=images, text=prompts_text, return_tensors='pt', padding=True, padding_side="left", add_special_tokens=False) + prompt_inputs = super()._prepare_inputs(prompt_inputs) + pixel_values, image_grid_thw = prompt_inputs['pixel_values'], prompt_inputs['image_grid_thw']""" function = pattern.sub(replacement, function) From af3f5e628fcd75c84fdd65fd4b504821a7062c9b Mon Sep 17 00:00:00 2001 From: GAD-cell Date: Thu, 19 Jun 2025 09:20:30 +0200 Subject: [PATCH 12/16] fixed _get_per_token_logps slicing --- unsloth/models/rl_replacements.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 1424f6211..b70c46e0b 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -329,8 +329,12 @@ def _get_per_token_logps(self, model, input_ids, attention_mask,pixel_values,ima else: hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits #logits = 
logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
-        hidden_states = hidden_states[:, :-1, :]
-        hidden_states = hidden_states[:, -logits_to_keep:, :]
+
+        hidden_states = hidden_states[:, :-1, :] # if not using fast path, we need to slice the last logit (also see PR #2702 from unsloth)
+
+        if hidden_states.size(1) != logits_to_keep + 1 : # Some models like Qwen VL don't have logits_to_keep parameter so you need to trim the output manually
+            hidden_states = hidden_states[:, -logits_to_keep:, :]
+
         return hidden_states
     # input_ids = input_ids[:, -logits_to_keep:]
     # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.

From 8e3fe8e4abd4c0b13a10bb03b8acd9327588c631 Mon Sep 17 00:00:00 2001
From: GAD-cell
Date: Thu, 19 Jun 2025 09:45:19 +0200
Subject: [PATCH 13/16] slicing condition was off

---
 unsloth/models/rl_replacements.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py
index b70c46e0b..1d1aef535 100644
--- a/unsloth/models/rl_replacements.py
+++ b/unsloth/models/rl_replacements.py
@@ -332,7 +332,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask,pixel_values,ima
         hidden_states = hidden_states[:, :-1, :] # if not using fast path, we need to slice the last logit (also see PR #2702 from unsloth)
 
-        if hidden_states.size(1) != logits_to_keep + 1 : # Some models like Qwen VL don't have logits_to_keep parameter so you need to trim the output manually
+        if hidden_states.size(1) != logits_to_keep : # Some models like Qwen VL don't have logits_to_keep parameter so you need to trim the output manually
             hidden_states = hidden_states[:, -logits_to_keep:, :]
 
         return hidden_states
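PATCH 12 and PATCH 13 converge on a trimming rule that is easy to state in isolation. The sketch below is illustrative only, with names of my choosing rather than the trainer's: a model that honors logits_to_keep returns logits_to_keep + 1 positions, while one that ignores it (as some Qwen VL variants do) returns the full sequence; dropping the final next-token position and then slicing any remaining excess makes both cases converge on exactly logits_to_keep positions.

    import torch

    def trim_hidden_states(hidden_states, logits_to_keep):
        # Drop the last position: it predicts the token *after* the sequence.
        hidden_states = hidden_states[:, :-1, :]
        # If the model ignored logits_to_keep, more positions remain than
        # needed; keep only the trailing completion positions.
        if hidden_states.size(1) != logits_to_keep:
            hidden_states = hidden_states[:, -logits_to_keep:, :]
        return hidden_states

    honored = torch.randn(2, 6, 8)   # model returned logits_to_keep + 1 = 6
    ignored = torch.randn(2, 20, 8)  # model returned the full sequence
    assert trim_hidden_states(honored, 5).shape == (2, 5, 8)
    assert trim_hidden_states(ignored, 5).shape == (2, 5, 8)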
From 9dbe0b9e144527c79f7aeb8332c9ae3f65d71cc3 Mon Sep 17 00:00:00 2001
From: GAD-cell
Date: Tue, 1 Jul 2025 09:38:59 +0200
Subject: [PATCH 14/16] spacing + arg on new line fix

---
 unsloth/models/rl.py              |  2 +-
 unsloth/models/rl_replacements.py | 24 +++++++++++++++++-------
 2 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py
index a34f98e30..897981008 100644
--- a/unsloth/models/rl.py
+++ b/unsloth/models/rl.py
@@ -134,9 +134,9 @@ class Unsloth{RLConfig_name}({RLConfig_name}):
         metadata = {{'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}},
     )
     def __init__({RLConfig_arguments},
-        use_vision = False,
         vllm_sampling_params = None,
         unsloth_num_chunks = -1,
+        use_vision = False,
         **kwargs,
     ):
 {RLConfig_extra_args}
diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py
index 3f09c8c5e..f6cf5d969 100644
--- a/unsloth/models/rl_replacements.py
+++ b/unsloth/models/rl_replacements.py
@@ -363,9 +363,19 @@ def _get_per_token_logps(self, model, input_ids, attention_mask,pixel_values,ima
         with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
             # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
             if self.use_vision :
-                hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values, image_grid_thw=image_grid_thw, logits_to_keep=logits_to_keep + 1).logits
+                hidden_states = model(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    pixel_values=pixel_values,
+                    image_grid_thw=image_grid_thw,
+                    logits_to_keep=logits_to_keep + 1
+                ).logits
             else:
-                hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
+                hidden_states = model(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    logits_to_keep=logits_to_keep + 1
+                ).logits
             #logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
 
             if hidden_states.size(1) != logits_to_keep+1 : # Some models like Qwen VL don't have logits_to_keep parameter so you need to trim the output manually
@@ -416,23 +426,23 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
         prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
         completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
-        pixel_values,image_grid_thw = inputs.get("pixel_values", None), inputs.get("image_grid_thw", None)
+        pixel_values, image_grid_thw = inputs.get("pixel_values", None), inputs.get("image_grid_thw", None)
+
         input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
         bsz, qlen = input_ids.shape
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
         # attention_mask = None
         logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
         _input_ids = input_ids
-        _logits_to_keep = logits_to_keep
-
-        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask,pixel_values,image_grid_thw, logits_to_keep)
+        _logits_to_keep = logits_to_keep
+        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep)
 
         # Compute the KL divergence between the model and the reference model
         # _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves.
         # https://github.com/huggingface/trl/blob/05bc43e960396581e458195b8388efe6b82cae1f/trl/trainer/grpo_trainer.py#L1328
         if self.beta != 0.0:
             with torch.inference_mode(), model.disable_adapter():
-                ref_per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask,pixel_values,image_grid_thw, logits_to_keep)
+                ref_per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep)
         else:
             ref_per_token_logps = None
         # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
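The `# x - x.detach()` comment and the commented-out `per_token_kl` / `per_token_loss` lines refer to a standard policy-gradient identity: `exp(logps - logps.detach())` evaluates to exactly 1 in the forward pass, but its derivative with respect to `logps` is 1, so multiplying by the advantages recovers the REINFORCE gradient. A minimal demonstration, with arbitrary values:

    import torch

    per_token_logps = torch.tensor([-1.2, -0.7, -2.3], requires_grad=True)
    advantages = torch.tensor([0.5, 0.5, 0.5])

    # Forward value is 1.0 per token, but the graph keeps d/dlogps = 1,
    # so the loss gradient matches that of logps * advantages.
    ratio = torch.exp(per_token_logps - per_token_logps.detach())
    loss = -(ratio * advantages).sum()
    loss.backward()

    print(ratio)                 # tensor([1., 1., 1.], grad_fn=<ExpBackward0>)
    print(per_token_logps.grad)  # tensor([-0.5000, -0.5000, -0.5000])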
From dfb05c148846110593b02897291b914b7b0b85eb Mon Sep 17 00:00:00 2001
From: GAD-cell
Date: Thu, 3 Jul 2025 00:32:36 +0200
Subject: [PATCH 15/16] efficient vlm grpo compute loss

---
 unsloth/models/rl_replacements.py | 35 +++++++++++++----------------------
 1 file changed, 13 insertions(+), 22 deletions(-)

diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py
index f6cf5d969..cad46333b 100644
--- a/unsloth/models/rl_replacements.py
+++ b/unsloth/models/rl_replacements.py
@@ -351,8 +351,8 @@ def _move_model_to_vllm(self, *args, **kwargs): return None
 
 def grpo_trainer__get_per_token_logps(function_name, function):
     if function_name != "_get_per_token_logps": return function
-    def _get_per_token_logps(self, model, input_ids, attention_mask,pixel_values,image_grid_thw, logits_to_keep, calc_logprob_flag = None):
-        if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0' and not calc_logprob_flag:
+    def _get_per_token_logps(self, model, input_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep):
+        if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
             return None # Unsloth efficient GRPO
         # Otherwise, calculate normally:
         if not hasattr(self, '_autocast_dtype'):
@@ -362,26 +362,13 @@ def _get_per_token_logps(self, model, input_ids, attention_mask,pixel_values,ima
         os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
         with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
             # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
-            if self.use_vision :
-                hidden_states = model(
-                    input_ids=input_ids,
-                    attention_mask=attention_mask,
-                    pixel_values=pixel_values,
-                    image_grid_thw=image_grid_thw,
-                    logits_to_keep=logits_to_keep + 1
-                ).logits
-            else:
-                hidden_states = model(
-                    input_ids=input_ids,
-                    attention_mask=attention_mask,
-                    logits_to_keep=logits_to_keep + 1
-                ).logits
-            #logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
-
-            if hidden_states.size(1) != logits_to_keep+1 : # Some models like Qwen VL don't have logits_to_keep parameter so you need to trim the output manually
-                hidden_states = hidden_states[:, -(logits_to_keep+1):, :]
-
-            return hidden_states
+            logits = model(
+                input_ids = input_ids,
+                attention_mask = attention_mask,
+                logits_to_keep = logits_to_keep + 1,
+            ).logits
+            # logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+            return logits
         # input_ids = input_ids[:, -logits_to_keep:]
         # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
# See https://github.com/huggingface/trl/issues/2770 @@ -492,6 +479,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch loss, completion_length, mean_kl = grpo_accumulated_loss( trainer = self, input_ids = _input_ids, + pixel_values = pixel_values, + image_grid_thw = image_grid_thw, logits_to_keep = logits_to_keep, completion_mask = completion_mask, advantages = advantages, @@ -513,6 +502,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch loss, completion_length, mean_kl = grpo_accumulated_loss( trainer = self, input_ids = _input_ids, + pixel_values = pixel_values, + image_grid_thw = image_grid_thw, logits_to_keep = logits_to_keep, completion_mask = completion_mask, advantages = advantages, From 12784ded01276bf90128ca6bc03d86e829e0a6da Mon Sep 17 00:00:00 2001 From: GAD-cell Date: Fri, 11 Jul 2025 08:50:19 +0200 Subject: [PATCH 16/16] fix --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index e8356611a..a2953aec9 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -479,7 +479,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch get_logps_func = \ lambda model, input_ids, attention_mask, logits_to_keep, batch_size=None, compute_entropy=False: \ - self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) \ + self._get_per_token_logps(model, input_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep) \ if hasattr(self, "_get_per_token_logps") else \ self._get_per_token_logps_and_entropies(model, input_ids, attention_mask, logits_to_keep, batch_size, compute_entropy)['logps']
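PATCH 16 is the closure pattern in miniature: trl invokes `get_logps_func` with a fixed argument list, so the extra vision tensors can only reach `_get_per_token_logps` by being captured from the enclosing `compute_loss` scope rather than passed explicitly. A hedged sketch of the same idea; the names below are illustrative, not trl's:

    def make_logps_func(get_per_token_logps, pixel_values, image_grid_thw):
        # The caller's calling convention stays fixed; the vision tensors
        # ride along inside the closure instead of the argument list.
        return lambda model, input_ids, attention_mask, logits_to_keep: \
            get_per_token_logps(model, input_ids, attention_mask,
                                pixel_values, image_grid_thw, logits_to_keep)

    # Text-only call sites simply close over None for both extras:
    def fake_logps(model, ids, mask, pv, grid, keep):
        return (model, ids, mask, pv, grid, keep)

    func = make_logps_func(fake_logps, pixel_values=None, image_grid_thw=None)
    print(func("m", "ids", "mask", 5))  # ('m', 'ids', 'mask', None, None, 5)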