[Feature] VLMs support for GRPO #2752
Conversation
It's implemented for a specific input type: { … }. There are still tasks to complete, particularly regarding the compute loss.
Fantastic work!
unsloth/models/rl_replacements.py
Outdated
if not self.use_vision:
    pixel_values = None
    image_grid_thw = None
    prompt_inputs = self.processing_class(text=prompts_text, return_tensors='pt', padding=True, padding_side="left", add_special_tokens=False)
    prompt_inputs = super()._prepare_inputs(prompt_inputs)
else:
    images = [x['image'] for x in inputs] # Only image inputs are supported for now
    prompt_inputs = self.processing_class(images=images, text=prompts_text, return_tensors='pt', padding=True, padding_side="left", add_special_tokens=False)
    prompt_inputs = super()._prepare_inputs(prompt_inputs)
    pixel_values, image_grid_thw = prompt_inputs['pixel_values'], prompt_inputs['image_grid_thw']
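Judging from the images = [x['image'] for x in inputs] line above, each dataset row in the vision path carries an 'image' entry next to the prompt. A minimal sketch of one such row; only the 'image' key is confirmed by this diff, and the 'prompt' structure is an assumption based on the usual TRL GRPO conversational format:

from PIL import Image

example = {
    "prompt": [{"role": "user", "content": "What is shown in this image?"}],  # assumed format, not shown in the diff
    "image": Image.open("sample.png"),  # the key read by the diff above; one PIL image per example
}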
Looks like this is the only change, along with this:
unsloth/models/rl_replacements.py
Outdated
"completion_mask": completion_mask, | ||
"advantages": advantages, | ||
"old_per_token_logps": old_per_token_logps, | ||
} |
Yes, and the returned output as well, plus here.
So there are 4 changes in total.
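To make the reviewer's point concrete: the dict returned by the prepare step would also need to carry the vision tensors. A hedged sketch, reusing the key names from the diffs above (the surrounding keys are not shown in full here):

output = {
    "completion_mask": completion_mask,
    "advantages": advantages,
    "old_per_token_logps": old_per_token_logps,
    "pixel_values": pixel_values,       # new: vision inputs from the processor
    "image_grid_thw": image_grid_thw,   # new: image grid metadata for Qwen-VL
}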
Hey @GAD-cell, thanks for the changes. Can you please take a screenshot of the generated changes in UnslothGRPOTrainer.py and paste it in the PR description?
OK @Datta0, here are the generated changes in UnslothGRPOTrainer:
BTW, I still need to implement the code for grpo_accumulated_loss; my version assumes the slow compute loss is used. I also have a question: there are two paths for computing the loss, and the code uses grpo_accumulated_loss if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0', so I'm not entirely sure when grpo_accumulated_loss is actually supposed to be used. My guess: since GRPO was initially developed only for language models, that was not an issue.
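For reference, a minimal sketch of the two-path selection described above; grpo_accumulated_loss and grpo_compute_loss_slow are the names used in unsloth/models/rl_replacements.py, but the call arguments are elided, so this is not the actual call site:

import os

if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
    # Fast path: the fused/accumulated loss kernel (text-only before this PR).
    loss = grpo_accumulated_loss(...)
else:
    # Slow path: plain per-token log-probs; the path this PR's VLM support takes.
    loss = grpo_compute_loss_slow(...)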
Update @Datta0 @danielhanchen: logits_to_keep is not a parameter of the forward pass for Qwen-VL models (see the Qwen-VL forward parameters), so I had to slice manually (see last commit). Can you confirm? I can also add VLM support for the fast path; I just need to change a few things in grpo_accumulated_loss and UnslothEfficientGRPO.
Hey @GAD-cell, I haven't looked at the entire code, just saw the last commit.
Thank you!
@GAD-cell Nice work again! Would it be possible to confirm that, say, the original Unsloth Qwen 4B GRPO notebook on our main GitHub page works as expected after your changes? Appreciate it! Also, would it be possible to provide a full working notebook? The goal is to highlight your work in the notebook itself (i.e. made by you), and we'll post about it!
Thank you! I just tested with your Qwen3 4B notebook; everything works correctly, including the training!
Hey @danielhanchen!
@GAD-cell Oh, the notebook looks very nice - great work! There are some spelling errors :) Also, maybe add a sentence somewhere like "notebook contributed by GAD-cell" with a hyperlink (if you want). Then also move the notebook to the notebooks repo in Unsloth :) There are also some merge conflicts :) After that, @Datta0, could you maybe run the notebook once and see if everything functions well?
OK, that's strange; I reproduced it and it's working for me.
I just opened the notebook in a new instance, with the first install of unsloth in the session, so it should not be an unsloth_compiled_cache thing. I am using a T4, but I don't think that should matter anyhow, right?
I tried again in a new instance and it's still working for me haha. Can't figure out what's going wrong. I did this (you can find the second part in the cell "Colab Extra install"):

import os
!pip install git+https://github.com/GAD-cell/unsloth.git@VLM_GRPO
!pip install --no-deps unsloth vllm==0.8.5.post1
import sys, re, requests; modules = list(sys.modules.keys())
for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
!pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer
# Added for this specific notebook
!pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
!pip install --no-deps -U transformers
!pip install --no-deps -U accelerate
!pip install --no-deps trl==0.18.2
# vLLM requirements - vLLM breaks Colab due to reinstalling numpy
f = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt").content
with open("vllm_requirements.txt", "wb") as file:
    file.write(re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", f))
!pip install -r vllm_requirements.txt

I'm going to try with a T4. Let me know if this worked.
Ok my bad :) |
Ok ok perfect :) glad it worked haha |
LGTM
Great work :)
Thank you for your time!
Hey @danielhanchen. I've resolved all the conflicts and tested the VL GRPO notebook and the regular GRPO notebook again.
unsloth/models/rl_replacements.py
Outdated
hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
#logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred

if hidden_states.size(1) != logits_to_keep + 1: # Some models like Qwen VL don't have a logits_to_keep parameter, so you need to trim the output manually
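For completeness, a sketch of the manual trim this diff introduces, assuming the positions to drop sit at the front of the sequence (the completion logits are at the end):

hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
if hidden_states.size(1) != logits_to_keep + 1:
    # The model ignored logits_to_keep (e.g. Qwen-VL), so keep only the last
    # logits_to_keep + 1 positions manually.
    hidden_states = hidden_states[:, -(logits_to_keep + 1):, :]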
Wait, I think we do this automatically in kernels.
Yes, you do it in grpo_accumulated_loss, but for now VLM GRPO uses grpo_compute_loss_slow, and in that case it needs to be trimmed. I've commented about this here. We can implement it for the fast path, but I didn't want to touch that part yet since the flag is not clear to me.
Oh, and sorry for the spacing; I forgot to double-check it :)
OK @danielhanchen, I've made the necessary changes. Apologies again; there was some confusion around get_per_token_logps, and I didn't realize it now always returns None.
input_ids = _input_ids,
pixel_values = pixel_values,
image_grid_thw = image_grid_thw,
logits_to_keep = logits_to_keep,
Changes for the efficient path here.
input_ids = _input_ids,
pixel_values = pixel_values,
image_grid_thw = image_grid_thw,
logits_to_keep = logits_to_keep,
and here
Hey there! I was testing out this branch to use GRPO for text-based tasks on a model that supports vision (a Qwen2.5-VL 3B model that I had primed already). I keep getting this error. I will paste my notebook here to see if it is an issue with my code (might be, I Frankensteined it), but it looks like an issue with HF Transformers, so I don't know.
Hey!
Sorry! I can paste the errors ASAP. I am using it all locally, not running this on Google Colab, hence why I don't have the other dependencies installed.
This was the error that I am getting:
This is due to the transformers version.
Ah! Thank you so much. It is working now!
One thing I noticed: when I went to run PPO fine-tuning for a different model, I ended up getting this error if I used this version.
And when I add
@GAD-cell Does this support Gemma 3n? It seems it's not compatible with the newest transformers version that's needed for 3n:
This PR aims to add support for VLMs in GRPO, which is currently not supported by HF.
I've implemented a working version that does not yet include vLLM or video input support (mainly due to limited resources for testing video inputs haha).
I added a new variable, use_vision, to the GRPO config. Setting use_vision = True enables vision inputs, while use_vision = False keeps the default GRPO behavior. Default is False.
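For illustration, a hedged usage sketch of the new flag; every field other than use_vision is a standard TRL GRPOConfig argument with placeholder values, not something prescribed by this PR:

from trl import GRPOConfig

training_args = GRPOConfig(
    use_vision = True,            # new flag from this PR; False keeps text-only GRPO
    max_prompt_length = 1024,     # placeholder
    max_completion_length = 256,  # placeholder
)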
I also had to change a function in unsloth_zoo.peft_utils (requires_grad_post_hook) to make it work.
I've tested the implementation with Qwen 2.5 VL 7B for 250 steps, and training appears to proceed correctly (see TensorBoard screenshots for reference).