Skip to content
Open
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions unsloth/models/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,14 @@ class Unsloth{RLConfig_name}({RLConfig_name}):
def __init__({RLConfig_arguments},
vllm_sampling_params = None,
unsloth_num_chunks = -1,
use_vision = False,
**kwargs,
):
{RLConfig_extra_args}
super().__init__({RLConfig_call_args}{RLConfig_kwargs})
self.vllm_sampling_params = vllm_sampling_params
self.unsloth_num_chunks = unsloth_num_chunks
self.use_vision = use_vision
pass

{RLTrainer_extras}
Expand Down Expand Up @@ -233,6 +235,12 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):

# Edit bf16, fp16 by checking model's torch_dtype directly
extra_args = ""

# Add boolean for vision support
if "args" in call_args :
use_vision = "self.use_vision = args.use_vision\n"
extra_args += use_vision

if "args" in call_args and "model" in call_args:
mixed_precision = \
"use_bf16 = getattr(args, 'bf16', False)\n"\
Expand Down
125 changes: 116 additions & 9 deletions unsloth/models/rl_replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,9 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch


# Autocast precision for GRPO
def grpo_trainer__prepare_inputs(function_name, function):
if function_name != "_prepare_inputs": return function
def grpo_generate_and_score_completions(function_name, function):
if function_name != "_generate_and_score_completions":
return function

import re
# This matches the function signature, decorators and any comments immediately following
Expand All @@ -186,7 +187,6 @@ def grpo_trainer__prepare_inputs(function_name, function):
# Find where the code block starts after comments
code_start_index = match.end(1)
rest_of_function = function[code_start_index:]

# Remove any old wake_up call that might be at the start of the function body
rest_of_function = re.sub(
r"^\s*if hasattr\(self, 'llm'\):.*?self\.llm\.wake_up\(\).*?\n",
Expand All @@ -205,6 +205,91 @@ def grpo_trainer__prepare_inputs(function_name, function):
)

function = header_and_comments + insert + rest_of_function

if """prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]""" not in function:
return function


# 1. Output pixel_values and image_grid_thw
pattern = re.compile(
r"^(?P<indent>\s*)return {\n"
r"(?P=indent) {4}\"prompt_ids\": prompt_ids,\n"
r"(?P=indent) {4}\"prompt_mask\": prompt_mask,\n"
r"(?P=indent) {4}\"completion_ids\": completion_ids,\n"
r"(?P=indent) {4}\"completion_mask\": completion_mask,\n"
r"(?P=indent) {4}\"advantages\": advantages,\n"
r"(?P=indent) {4}\"old_per_token_logps\": old_per_token_logps,\n"
r"(?P=indent)}",
re.MULTILINE
)

replacement = """ return {
"prompt_ids": prompt_ids,
"prompt_mask": prompt_mask,
"pixel_values": pixel_values,
"image_grid_thw": image_grid_thw,
"completion_ids": completion_ids,
"completion_mask": completion_mask,
"advantages": advantages,
"old_per_token_logps": old_per_token_logps,
}"""
function = re.sub(pattern, replacement, function)

# 2. Replace the prompt_completion_ids generation
pattern = re.compile(
r"^(?P<indent>\s*)prompt_completion_ids = unwrapped_model\.generate\(\n"
r"(?P=indent) {4}prompt_ids, attention_mask=prompt_mask, generation_config=self\.generation_config\n"
r"(?P=indent)\)",
re.MULTILINE
)

replacement = """ if self.use_vision : prompt_completion_ids = unwrapped_model.generate(prompt_ids, attention_mask=prompt_mask,pixel_values = pixel_values,image_grid_thw=image_grid_thw, generation_config=self.generation_config)
else : prompt_completion_ids = unwrapped_model.generate(prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config)"""

function = pattern.sub(replacement, function)

# 3. Replace the old_per_token_logps generation
pattern = re.compile(
r"^(?P<indent>\s*)old_per_token_logps = self\._get_per_token_logps\(\n"
r"(?P=indent) {4}self\.model, prompt_completion_ids, attention_mask, logits_to_keep, batch_size\n"
r"(?P=indent)\)",
re.MULTILINE
)

replacement = """ old_per_token_logps = self._get_per_token_logps(
self.model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep, batch_size
)"""

function = re.sub(pattern, replacement, function)

# 4. Replace the prompt processing section
pattern = re.compile(
r"^(?P<indent>\s*)prompts = \[x\[\"prompt\"\] for x in inputs\]\n"
r"(?P=indent)prompts_text = \[maybe_apply_chat_template\(example, self\.processing_class\)\[\"prompt\"\] for example in inputs\]\n"
r"(?P=indent)prompt_inputs = self\.processing_class\(\n"
r"(?P=indent) {4}text=prompts_text, return_tensors=\"pt\", padding=True, padding_side=\"left\", add_special_tokens=False\n"
r"(?P=indent)\)\n"
r"(?P=indent)prompt_inputs = super\(\)\._prepare_inputs\(prompt_inputs\)"
,
re.MULTILINE
)

replacement = """ prompts = [x["prompt"] for x in inputs]
prompts_text = [maybe_apply_chat_template(example, self.processing_class)['prompt'] for example in inputs]
if not self.use_vision:
pixel_values, image_grid_thw = None, None
prompt_inputs = self.processing_class(text=prompts_text, return_tensors='pt', padding=True, padding_side="left", add_special_tokens=False)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
else:
images = [x['image'] for x in inputs] # Only image inputs support for now
prompt_inputs = self.processing_class(images=images, text=prompts_text, return_tensors='pt', padding=True, padding_side="left", add_special_tokens=False)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
pixel_values, image_grid_thw = prompt_inputs['pixel_values'], prompt_inputs['image_grid_thw']"""

function = pattern.sub(replacement, function)



# Add mixed precision training
function = function.replace(
"with torch.inference_mode():",
Expand All @@ -230,7 +315,24 @@ def grpo_trainer__prepare_inputs(function_name, function):
function = function.rstrip() + "\n " + sleep_and_cache
return function
pass
RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__prepare_inputs)

RL_FUNCTIONS["grpo_trainer"].append(grpo_generate_and_score_completions)

def grpo_prepare_inputs(function_name, function):
    """Patch the source text of ``_prepare_inputs`` to reshape vision inputs.

    After every line that contains the call
    ``generation_batch = self._generate_and_score_completions(generation_batch)``
    a guarded statement is inserted which, when ``self.use_vision`` is truthy,
    re-views ``generation_batch['pixel_values']`` as
    ``(generation_batch['prompt_ids'].size(0), -1, pixel_values.size(1))``.

    The inserted statement is placed on its own line and reuses the leading
    whitespace of the line it follows, so the patched source stays valid
    whatever indentation depth the original call sits at.

    Args:
        function_name: Name of the function whose source is being patched.
        function: Source code of that function, as a string.

    Returns:
        The patched source, or ``function`` unchanged when ``function_name``
        is not ``"_prepare_inputs"`` or the target call is not present.
    """
    if function_name != "_prepare_inputs": return function

    target = "generation_batch = self._generate_and_score_completions(generation_batch)"
    if target not in function: return function

    # Kept as a single-line `if` so that it can be dropped in after the target
    # line without having to open a new indented block in the patched source.
    reshape_pixel_values = \
        "if self.use_vision : generation_batch['pixel_values']=generation_batch['pixel_values'].view(generation_batch['prompt_ids'].size(0), -1, generation_batch['pixel_values'].size(1)) # (batch_size * n_patches, dim embedding)->(batch_size,n_patches,dim embeddding)"

    patched_lines = []
    for line in function.split("\n"):
        patched_lines.append(line)
        if target in line:
            # Mirror the indentation of the matched line; a hard-coded indent
            # would raise IndentationError whenever the depths disagree.
            indent = line[:len(line) - len(line.lstrip())]
            patched_lines.append(indent + reshape_pixel_values)

    return "\n".join(patched_lines)
pass
RL_FUNCTIONS["grpo_trainer"].append(grpo_prepare_inputs)


# Remove _move_model_to_vllm
Expand All @@ -249,7 +351,7 @@ def _move_model_to_vllm(self, *args, **kwargs): return None
def grpo_trainer__get_per_token_logps(function_name, function):
if function_name != "_get_per_token_logps": return function

def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
def _get_per_token_logps(self, model, input_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep):
if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
return None # Unsloth efficient GRPO
# Otherwise, calculate normally:
Expand Down Expand Up @@ -311,22 +413,23 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch

prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
pixel_values, image_grid_thw = inputs.get("pixel_values", None), inputs.get("image_grid_thw", None)

input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
bsz, qlen = input_ids.shape
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
# attention_mask = None
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
_input_ids = input_ids
_logits_to_keep = logits_to_keep

per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
_logits_to_keep = logits_to_keep
per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep)

# Compute the KL divergence between the model and the reference model
# _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves.
# https://github.com/huggingface/trl/blob/05bc43e960396581e458195b8388efe6b82cae1f/trl/trainer/grpo_trainer.py#L1328
if self.beta != 0.0:
with torch.inference_mode(), model.disable_adapter():
ref_per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
ref_per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep)
else:
ref_per_token_logps = None
# per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
Expand Down Expand Up @@ -376,6 +479,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
loss, completion_length, mean_kl = grpo_accumulated_loss(
trainer = self,
input_ids = _input_ids,
pixel_values = pixel_values,
image_grid_thw = image_grid_thw,
logits_to_keep = logits_to_keep,
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changes for efficient path here

completion_mask = completion_mask,
advantages = advantages,
Expand All @@ -397,6 +502,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
loss, completion_length, mean_kl = grpo_accumulated_loss(
trainer = self,
input_ids = _input_ids,
pixel_values = pixel_values,
image_grid_thw = image_grid_thw,
logits_to_keep = logits_to_keep,
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and here

completion_mask = completion_mask,
advantages = advantages,
Expand Down