diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 889bbd480..897981008 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -136,12 +136,14 @@ class Unsloth{RLConfig_name}({RLConfig_name}): def __init__({RLConfig_arguments}, vllm_sampling_params = None, unsloth_num_chunks = -1, + use_vision = False, **kwargs, ): {RLConfig_extra_args} super().__init__({RLConfig_call_args}{RLConfig_kwargs}) self.vllm_sampling_params = vllm_sampling_params self.unsloth_num_chunks = unsloth_num_chunks + self.use_vision = use_vision pass {RLTrainer_extras} @@ -233,6 +235,12 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Edit bf16, fp16 by checking model's torch_dtype directly extra_args = "" + + # Add boolean for vision support + if "args" in call_args : + use_vision = "self.use_vision = args.use_vision\n" + extra_args += use_vision + if "args" in call_args and "model" in call_args: mixed_precision = \ "use_bf16 = getattr(args, 'bf16', False)\n"\ diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 4d5b4d7a4..a2953aec9 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -170,8 +170,9 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # Autocast precision for GRPO -def grpo_trainer__prepare_inputs(function_name, function): - if function_name != "_prepare_inputs": return function +def grpo_generate_and_score_completions(function_name, function): + if function_name != "_generate_and_score_completions": + return function import re # This matches the function signature, decorators and any comments immediately following @@ -188,7 +189,6 @@ def grpo_trainer__prepare_inputs(function_name, function): # Find where the code block starts after comments code_start_index = match.end(1) rest_of_function = function[code_start_index:] - # Remove any old wake_up call that might be at the start of the function body rest_of_function = re.sub( r"^\s*if hasattr\(self, 'llm'\):.*?self\.llm\.wake_up\(\).*?\n", @@ -207,6 +207,91 @@ def grpo_trainer__prepare_inputs(function_name, function): ) function = header_and_comments + insert + rest_of_function + + if """prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]""" not in function: + return function + + + # 1. Output pixel_values and image_grid_thw + pattern = re.compile( + r"^(?P\s*)return {\n" + r"(?P=indent) {4}\"prompt_ids\": prompt_ids,\n" + r"(?P=indent) {4}\"prompt_mask\": prompt_mask,\n" + r"(?P=indent) {4}\"completion_ids\": completion_ids,\n" + r"(?P=indent) {4}\"completion_mask\": completion_mask,\n" + r"(?P=indent) {4}\"advantages\": advantages,\n" + r"(?P=indent) {4}\"old_per_token_logps\": old_per_token_logps,\n" + r"(?P=indent)}", + re.MULTILINE + ) + + replacement = """ return { + "prompt_ids": prompt_ids, + "prompt_mask": prompt_mask, + "pixel_values": pixel_values, + "image_grid_thw": image_grid_thw, + "completion_ids": completion_ids, + "completion_mask": completion_mask, + "advantages": advantages, + "old_per_token_logps": old_per_token_logps, + }""" + function = re.sub(pattern, replacement, function) + + # 2. Replace the prompt_completion_ids generation + pattern = re.compile( + r"^(?P\s*)prompt_completion_ids = unwrapped_model\.generate\(\n" + r"(?P=indent) {4}prompt_ids, attention_mask=prompt_mask, generation_config=self\.generation_config\n" + r"(?P=indent)\)", + re.MULTILINE + ) + + replacement = """ if self.use_vision : prompt_completion_ids = unwrapped_model.generate(prompt_ids, attention_mask=prompt_mask,pixel_values = pixel_values,image_grid_thw=image_grid_thw, generation_config=self.generation_config) + else : prompt_completion_ids = unwrapped_model.generate(prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config)""" + + function = pattern.sub(replacement, function) + + # 3. Replace the old_per_token_logps generation + pattern = re.compile( + r"^(?P\s*)old_per_token_logps = self\._get_per_token_logps\(\n" + r"(?P=indent) {4}self\.model, prompt_completion_ids, attention_mask, logits_to_keep, batch_size\n" + r"(?P=indent)\)", + re.MULTILINE + ) + + replacement = """ old_per_token_logps = self._get_per_token_logps( + self.model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep, batch_size + )""" + + function = re.sub(pattern, replacement, function) + + # 4. Replace the prompt processing section + pattern = re.compile( + r"^(?P\s*)prompts = \[x\[\"prompt\"\] for x in inputs\]\n" + r"(?P=indent)prompts_text = \[maybe_apply_chat_template\(example, self\.processing_class\)\[\"prompt\"\] for example in inputs\]\n" + r"(?P=indent)prompt_inputs = self\.processing_class\(\n" + r"(?P=indent) {4}text=prompts_text, return_tensors=\"pt\", padding=True, padding_side=\"left\", add_special_tokens=False\n" + r"(?P=indent)\)\n" + r"(?P=indent)prompt_inputs = super\(\)\._prepare_inputs\(prompt_inputs\)" + , + re.MULTILINE + ) + + replacement = """ prompts = [x["prompt"] for x in inputs] + prompts_text = [maybe_apply_chat_template(example, self.processing_class)['prompt'] for example in inputs] + if not self.use_vision: + pixel_values, image_grid_thw = None, None + prompt_inputs = self.processing_class(text=prompts_text, return_tensors='pt', padding=True, padding_side="left", add_special_tokens=False) + prompt_inputs = super()._prepare_inputs(prompt_inputs) + else: + images = [x['image'] for x in inputs] # Only image inputs support for now + prompt_inputs = self.processing_class(images=images, text=prompts_text, return_tensors='pt', padding=True, padding_side="left", add_special_tokens=False) + prompt_inputs = super()._prepare_inputs(prompt_inputs) + pixel_values, image_grid_thw = prompt_inputs['pixel_values'], prompt_inputs['image_grid_thw']""" + + function = pattern.sub(replacement, function) + + + # Add mixed precision training function = function.replace( "with torch.inference_mode():", @@ -232,7 +317,24 @@ def grpo_trainer__prepare_inputs(function_name, function): function = function.rstrip() + "\n " + sleep_and_cache return function pass -RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__prepare_inputs) + +RL_FUNCTIONS["grpo_trainer"].append(grpo_generate_and_score_completions) + +def grpo_prepare_inputs(function_name, function): + if function_name != "_prepare_inputs": return function + + if "generation_batch = self._generate_and_score_completions(generation_batch)" not in function : return function + + function = function.replace( + "generation_batch = self._generate_and_score_completions(generation_batch)", + + "generation_batch = self._generate_and_score_completions(generation_batch)\n"\ + " if self.use_vision : generation_batch['pixel_values']=generation_batch['pixel_values'].view(generation_batch['prompt_ids'].size(0), -1, generation_batch['pixel_values'].size(1)) # (batch_size * n_patches, dim embedding)->(batch_size,n_patches,dim embeddding)" + ) + + return function +pass +RL_FUNCTIONS["grpo_trainer"].append(grpo_prepare_inputs) # Remove _move_model_to_vllm @@ -251,7 +353,7 @@ def _move_model_to_vllm(self, *args, **kwargs): return None def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function - def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): + def _get_per_token_logps(self, model, input_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep): if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0': return None # Unsloth efficient GRPO # Otherwise, calculate normally: @@ -365,6 +467,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] + pixel_values, image_grid_thw = inputs.get("pixel_values", None), inputs.get("image_grid_thw", None) + input_ids = torch.cat([prompt_ids, completion_ids], dim=1) bsz, qlen = input_ids.shape attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) @@ -375,7 +479,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch get_logps_func = \ lambda model, input_ids, attention_mask, logits_to_keep, batch_size=None, compute_entropy=False: \ - self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) \ + self._get_per_token_logps(model, input_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep) \ if hasattr(self, "_get_per_token_logps") else \ self._get_per_token_logps_and_entropies(model, input_ids, attention_mask, logits_to_keep, batch_size, compute_entropy)['logps'] @@ -436,6 +540,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch loss, completion_length, mean_kl = grpo_accumulated_loss( trainer = self, input_ids = _input_ids, + pixel_values = pixel_values, + image_grid_thw = image_grid_thw, logits_to_keep = logits_to_keep, completion_mask = completion_mask, advantages = advantages, @@ -457,6 +563,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch loss, completion_length, mean_kl = grpo_accumulated_loss( trainer = self, input_ids = _input_ids, + pixel_values = pixel_values, + image_grid_thw = image_grid_thw, logits_to_keep = logits_to_keep, completion_mask = completion_mask, advantages = advantages,