grpo liger loss #3781

Open

wants to merge 74 commits into main

Commits (74)
c3f859d
liger grpo loss
hjh0119 Apr 7, 2025
5224a4a
merge main
hjh0119 Apr 14, 2025
bbce4b2
update
hjh0119 Apr 14, 2025
63fdcea
fix
hjh0119 Apr 14, 2025
5915901
move args
hjh0119 Apr 14, 2025
d0c290c
fix
hjh0119 Apr 14, 2025
0a3794f
fix
hjh0119 Apr 14, 2025
3b9ee6d
fix
hjh0119 Apr 14, 2025
d643ab9
fix
hjh0119 Apr 14, 2025
93fdb71
require
hjh0119 Apr 15, 2025
f87b042
compatible with zero3
hjh0119 Apr 15, 2025
b82cbf4
fix
hjh0119 Apr 15, 2025
9c20051
merge main
hjh0119 May 1, 2025
fc7fabe
wip
hjh0119 May 1, 2025
8f67b13
update liger loss
hjh0119 May 1, 2025
8b4e346
liger&peft
hjh0119 May 1, 2025
edc1fd1
init
hjh0119 May 6, 2025
07a1040
fix default
hjh0119 May 6, 2025
0303461
fix
hjh0119 May 7, 2025
854f357
fix seed
hjh0119 May 7, 2025
7df2b5d
fix
hjh0119 May 7, 2025
fda82ee
wip
hjh0119 May 7, 2025
5d8d4a2
wip multi turn
hjh0119 May 7, 2025
ac52340
multi turn
hjh0119 May 7, 2025
578a365
fix comment
hjh0119 May 7, 2025
9a49fb5
fix peft model inspect and labels
hjh0119 May 7, 2025
5579c3e
fix multi turn
hjh0119 May 7, 2025
7de8aab
update multi turn
hjh0119 May 7, 2025
438f1f7
multi turn not remove response
hjh0119 May 8, 2025
d69a9ae
fix
hjh0119 May 8, 2025
451fd02
fix multi turn concate response
hjh0119 May 8, 2025
c3a1aa9
fix multi turn message check
hjh0119 May 8, 2025
300610e
fix infer
hjh0119 May 8, 2025
fd08ccd
external async generate
hjh0119 May 8, 2025
9da6242
clean argument check
hjh0119 May 8, 2025
8a22c9b
fix async generate
hjh0119 May 8, 2025
8ba0330
fix server infer to list
hjh0119 May 8, 2025
0926a3c
fix server infer
hjh0119 May 8, 2025
0c3827a
catch async generate error
hjh0119 May 8, 2025
fbc2b54
fix infer inputs
hjh0119 May 8, 2025
57445b4
fix async generate
hjh0119 May 8, 2025
e2330f9
fix size
hjh0119 May 8, 2025
37a06f9
remove vllm context
hjh0119 May 9, 2025
66ad138
reward model prepare ds
hjh0119 May 9, 2025
a1f1636
merge main
hjh0119 May 12, 2025
f4a05d3
lint
hjh0119 May 12, 2025
2b5198e
fix multi turn + TP
hjh0119 May 12, 2025
a479465
external path image
hjh0119 May 12, 2025
1fb25db
fix async generate and doc
hjh0119 May 12, 2025
7394dc9
update doc
hjh0119 May 12, 2025
4160ad3
remove async mode script
hjh0119 May 12, 2025
47bb902
doc wip and deprecate patch
hjh0119 May 12, 2025
37c68d2
lint
hjh0119 May 12, 2025
f7700fa
doc and scipt wip
hjh0119 May 13, 2025
6a572fa
doc update
hjh0119 May 13, 2025
4afbdc3
doc
hjh0119 May 13, 2025
df2ce3d
doc update
hjh0119 May 13, 2025
b101e4b
doc update
hjh0119 May 13, 2025
1939873
update doc and readme
hjh0119 May 13, 2025
dae81c1
update grpo doc
hjh0119 May 13, 2025
05054d0
update scripts
hjh0119 May 13, 2025
11307be
rm script
hjh0119 May 13, 2025
7bbed3f
update completion_length_limit_scope argument
hjh0119 May 13, 2025
53a08d0
merge refactor
hjh0119 May 13, 2025
829a7ea
fix epsilon
hjh0119 May 13, 2025
f2b4aac
update stable doc reference
hjh0119 May 13, 2025
cb7ff52
remove lmdeploy
hjh0119 May 13, 2025
5e9e3b5
set different seed bewteen processes
hjh0119 May 13, 2025
25ac346
fix seed
hjh0119 May 13, 2025
427a32f
merge refactor
hjh0119 May 13, 2025
c4dc72e
merge main
hjh0119 May 13, 2025
346396f
remove liger check
hjh0119 May 13, 2025
3045802
fix epsilon
hjh0119 May 13, 2025
4bf7996
remvoe unused import
hjh0119 May 14, 2025
2 changes: 2 additions & 0 deletions requirements/install_all.sh
@@ -9,4 +9,6 @@ pip install timm -U
pip install deepspeed -U
pip install qwen_vl_utils qwen_omni_utils decord librosa pyav icecream soundfile -U
pip install liger_kernel nvitop pre-commit -U
pip install wandb
pip install math_verify==0.5.2
# flash-attn: https://github.com/Dao-AILab/flash-attention/releases
7 changes: 6 additions & 1 deletion swift/llm/argument/rlhf_args.py
@@ -217,13 +217,18 @@ def _set_default(self):
def _check_grpo(self):
if self.rlhf_type != 'grpo':
return

from packaging import version

import trl
trl_version = version.parse(trl.__version__)
assert trl_version >= version.parse('0.17'), ('Your current version of `trl` is outdated. '
'Please update it by running: pip install -U trl')

if self.use_liger_loss:
from trl.import_utils import is_liger_kernel_available
assert is_liger_kernel_available(), (
'Please install/update liger-kernel by running: pip install -U liger-kernel')

if self.num_generations < 2:
raise ValueError(
'GRPO requires at least 2 generations per prompt to calculate the advantages. You provided '
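
For reviewers who want to verify their environment against the new checks before enabling the liger loss, here is a minimal standalone sketch that mirrors the assertions in `_check_grpo` above (it assumes only that `trl` and `liger-kernel` are installed, as in requirements/install_all.sh):

from packaging import version

import trl
from trl.import_utils import is_liger_kernel_available

# Mirrors _check_grpo: this trainer requires trl >= 0.17.
assert version.parse(trl.__version__) >= version.parse('0.17'), 'Please update trl: pip install -U trl'
# use_liger_loss additionally requires liger-kernel to be importable.
assert is_liger_kernel_available(), 'Please install/update liger-kernel: pip install -U liger-kernel'
print('GRPO liger-loss prerequisites satisfied')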
3 changes: 2 additions & 1 deletion swift/llm/template/base.py
@@ -1147,7 +1147,8 @@ def pre_forward_hook(self, model: nn.Module, args, kwargs):
old_kwargs = to_device(kwargs, model.device)
kwargs = to_device(self._post_encode(model, old_kwargs), model.device)
for k, v in old_kwargs.items():
if k in {'input_ids', 'attention_mask', 'labels', 'position_ids'} and k not in kwargs:
if k in {'input_ids', 'attention_mask', 'labels', 'position_ids', 'output_hidden_states'
} and k not in kwargs:
kwargs[k] = v
if 'inputs_embeds' in kwargs:
kwargs.pop('input_ids', None)
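
The extra whitelisted key matters for the liger path below: `_get_last_hidden_state` calls the multimodal model with `output_hidden_states=True`, and without forwarding that key through the pre-forward hook it would be dropped after `_post_encode`. A tiny self-contained illustration of the forwarding rule (the dicts are hypothetical stand-ins for the hook's kwargs):

old_kwargs = {'attention_mask': [[1, 1, 1]], 'output_hidden_states': True}  # hypothetical original kwargs
kwargs = {'inputs_embeds': 'tensor produced by _post_encode'}               # hypothetical _post_encode output

# Same rule as pre_forward_hook: copy whitelisted keys back if _post_encode did not return them.
for k, v in old_kwargs.items():
    if k in {'input_ids', 'attention_mask', 'labels', 'position_ids', 'output_hidden_states'} and k not in kwargs:
        kwargs[k] = v

assert kwargs['output_hidden_states'] is True  # the flag now survives to the model call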
2 changes: 2 additions & 0 deletions swift/trainers/arguments.py
@@ -203,6 +203,8 @@ class GRPOArgumentsMixin:
# dataset
dataset_shuffle: Optional[bool] = True

use_liger_loss: bool = False


@dataclass
class TrainingArguments(SwiftArgumentsMixin, HfTrainingArguments):
4 changes: 2 additions & 2 deletions swift/trainers/rlhf_trainer/__init__.py
@@ -12,7 +12,7 @@
from .ppo_trainer import PPOTrainer
from .reward_trainer import RewardTrainer
from .rlhf_mixin import RLHFTrainerMixin
from .utils import patch_lora_merge, patch_lora_unmerge, round_robin
from .utils import patch_lora_merge, patch_lora_unmerge, round_robin, _ForwardRedirection
else:
_import_structure = {
'cpo_trainer': ['CPOTrainer'],
@@ -23,7 +23,7 @@
'ppo_trainer': ['PPOTrainer'],
'reward_trainer': ['RewardTrainer'],
'rlhf_mixin': ['RLHFTrainerMixin'],
'utils': ['patch_lora_merge', 'patch_lora_unmerge', 'round_robin'],
'utils': ['patch_lora_merge', 'patch_lora_unmerge', 'round_robin', '_ForwardRedirection'],
}

import sys
118 changes: 101 additions & 17 deletions swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -36,7 +36,7 @@
from swift.utils import JsonlWriter, gc_collect, get_device, get_logger, is_vllm_available, is_wandb_available
from ..mixin import SwiftMixin
from .rlhf_mixin import RLHFTrainerMixin
from .utils import patch_lora_merge, patch_lora_unmerge, unwrap_model_for_generation
from .utils import _ForwardRedirection, patch_lora_merge, patch_lora_unmerge, unwrap_model_for_generation
from .vllm_client import VLLMClient

del HFGRPOTrainer.__init__
@@ -179,6 +179,25 @@ def __init__(self,
vllm_client = kwargs.pop('vllm_client') # for external vllm

super().__init__(model, ref_model, *_args, **kwargs)
# Multi-step
self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
self.epsilon_low = args.epsilon
self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon

self.use_liger_loss = self.args.use_liger_loss
if self.use_liger_loss:
from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss

self.liger_grpo_loss = LigerFusedLinearGRPOLoss(
beta=self.beta,
epsilon_low=self.epsilon_low,
epsilon_high=self.epsilon_high,
temperature=self.temperature,
use_ref_model=self.beta != 0.0,
loss_type=self.loss_type,
max_completion_length=self.max_completion_length,
)
self._forward_redirection = _ForwardRedirection()

self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)}
self.log_completions = args.log_completions
@@ -275,11 +294,6 @@ def __init__(self,
self.reward_funcs[i] = self.accelerator.prepare_model(
reward_func, evaluation_mode=True, device_placement=True)

# Multi-step
self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
self.epsilon_low = args.epsilon
self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon

# Tracks the number of iterations (forward + backward passes), including those within a gradient accumulation cycle. # noqa
self._step = 0
# Buffer the batch to reuse generated outputs across multiple updates. For more details, see
@@ -495,7 +509,7 @@ def _move_model_to_vllm(self):
if self.args.async_generate:
            # before syncing weights, wait for the async generation to finish
self._wait_queue()
if self.args.use_vllm:
if self.use_vllm:
llm_model = self.engine.inner_model
else:
llm_model = self.engine.engine.engine
@@ -949,15 +963,6 @@ def _prepare_batch_inputs(self, inputs: InputsType, rewards: torch.Tensor) -> Li
batch_encoded_inputs['old_per_token_logps'] = (
self._get_per_token_logps(self.model, batch_encoded_inputs) if self.old_policy else None)

if self.beta == 0.0:
ref_per_token_logps = None
elif self.ref_model is not None:
ref_per_token_logps = self._get_per_token_logps(self.ref_model, batch_encoded_inputs)
else:
with self.accelerator.unwrap_model(self.model).disable_adapter():
ref_per_token_logps = self._get_per_token_logps(self.model, batch_encoded_inputs)
batch_encoded_inputs['ref_per_token_logps'] = ref_per_token_logps

ga_batch_encoded_inputs.append(batch_encoded_inputs)

return ga_batch_encoded_inputs
@@ -1004,6 +1009,13 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
if isinstance(inputs, list):
assert len(inputs) == 1
inputs = inputs[0]
if self.use_liger_loss:
unwrapped_model = self.accelerator.unwrap_model(model)
return self._forward_redirection(model, unwrapped_model, self.compute_liger_loss, unwrapped_model, inputs)
else:
return self._compute_loss(model, inputs)

def _compute_loss(self, model, inputs):
completion_mask = inputs['completion_mask']
truncated_mask = inputs['truncated_mask']
# apply the completion_mask to exclude loss and metrics for overlong completions
@@ -1017,7 +1029,13 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N

# Compute the KL divergence between the model and the reference model
if self.beta != 0.0:
ref_per_token_logps = inputs['ref_per_token_logps']
with torch.no_grad():
if self.ref_model is not None:
ref_per_token_logps = self._get_per_token_logps(self.ref_model, inputs)
else:
with self.accelerator.unwrap_model(self.model).disable_adapter():
ref_per_token_logps = self._get_per_token_logps(self.model, inputs)

per_token_kl = (
torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1)

@@ -1096,6 +1114,72 @@ def _get_per_token_logps(self, model, inputs):
input_ids = input_ids[:, -logits_to_keep:]
return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens

@profiling_decorator
def _get_last_hidden_state(self, unwrapped_model, inputs, logits_to_keep):
# unwrap the model to access the model.model
if is_peft_model(unwrapped_model):
unwrapped_model = unwrapped_model.base_model.model
if not unwrapped_model.model_meta.is_multimodal:
last_hidden_state = unwrapped_model.model(
input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask']).last_hidden_state
else:
inputs = {
k: v
for k, v in inputs.items() if k not in [
'logits_to_keep', 'completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps',
'truncated_mask'
]
}
with self._template_context(self.template):
outputs = unwrapped_model(**inputs, output_hidden_states=True)
last_hidden_state = outputs.hidden_states[-1]

last_hidden_state = last_hidden_state[:, :-1, :] # (B, L-1, H)
if logits_to_keep is not None:
last_hidden_state = last_hidden_state[:, -logits_to_keep:, :] # (B, logits_to_keep, H)
return last_hidden_state

def compute_liger_loss(self, unwrapped_model, inputs):
# Compute the per-token log probabilities for the model
input_ids = inputs['input_ids']
logits_to_keep = inputs['logits_to_keep']
completion_ids = input_ids[:, -logits_to_keep:]
completion_mask = inputs['completion_mask']

# Compute the KL divergence between the model and the reference model
ref_per_token_logps = None
if self.beta != 0.0:
with torch.no_grad():
if self.ref_model is not None:
ref_per_token_logps = self._get_per_token_logps(self.ref_model, inputs)
else:
with self.accelerator.unwrap_model(self.model).disable_adapter():
ref_per_token_logps = self._get_per_token_logps(self.model, inputs)

# get the last hidden state of the model
last_hidden_state = self._get_last_hidden_state(unwrapped_model, inputs, logits_to_keep)
# compute loss and metrics using liger grpo loss
loss, metrics = self.liger_grpo_loss(
_input=last_hidden_state,
lin_weight=unwrapped_model.lm_head.weight,
selected_token_ids=completion_ids,
attention_mask=completion_mask,
advantages=inputs['advantages'],
bias=unwrapped_model.lm_head.bias,
old_per_token_logps=inputs['old_per_token_logps'],
ref_per_token_logps=ref_per_token_logps,
)
# Extract metrics from the liger_grpo_loss output
# KL divergence is the first metric when beta is non-zero
mean_kl = metrics[0] if self.beta != 0.0 else None
clip_ratio = metrics[-1]

mode = 'eval' if self.control.should_evaluate else 'train'
if self.beta != 0.0:
self._metrics[mode]['kl'].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
self._metrics[mode]['clip_ratio'].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
return loss

def evaluation_loop(self, dataloader, *args, **kwargs):
# Wait for the training rollout to complete
if self.args.async_generate:
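
To make the liger call above concrete, here is a rough sketch of invoking LigerFusedLinearGRPOLoss directly on dummy tensors, using only the constructor and forward keyword arguments that the trainer passes. Shapes and values are illustrative, it assumes a CUDA device and a liger-kernel release that ships this module, and it has not been validated against any particular version:

import torch
from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss

torch.manual_seed(0)
B, T, H, V = 2, 8, 16, 32  # toy batch size, completion length, hidden size, vocab size
device = 'cuda'            # assumption: liger kernels target GPU execution

# beta=0.0 disables the KL term, so no reference log-probs are needed (use_ref_model=False),
# matching `use_ref_model=self.beta != 0.0` in the trainer above.
loss_fn = LigerFusedLinearGRPOLoss(
    beta=0.0, epsilon_low=0.2, epsilon_high=0.2, temperature=1.0, use_ref_model=False)

last_hidden_state = torch.randn(B, T, H, device=device, requires_grad=True)  # as from _get_last_hidden_state
lm_head_weight = torch.randn(V, H, device=device, requires_grad=True)        # stands in for lm_head.weight
completion_ids = torch.randint(0, V, (B, T), device=device)                  # selected_token_ids
completion_mask = torch.ones(B, T, dtype=torch.long, device=device)          # mask over completion tokens
advantages = torch.randn(B, device=device)

loss, metrics = loss_fn(
    _input=last_hidden_state,
    lin_weight=lm_head_weight,
    selected_token_ids=completion_ids,
    attention_mask=completion_mask,
    advantages=advantages,
    bias=None,
    old_per_token_logps=None,  # on-policy case: falls back to the current policy's log-probs
    ref_per_token_logps=None,  # only needed when beta != 0
)
loss.backward()
print(loss.item(), metrics)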
46 changes: 46 additions & 0 deletions swift/trainers/rlhf_trainer/utils.py
@@ -5,6 +5,7 @@
import torch
from peft.tuners import lora
from peft.tuners.lora import LoraLayer
from torch import nn


def round_robin(num_reqs, num_workers):
@@ -157,3 +158,48 @@ def unwrap_model_for_generation(
add_hooks(model)
else:
yield unwrapped_model


class _ForwardRedirection:
"""Implements the `forward-redirection`.
Taken from Pytorch-lightning:
https://github.com/Lightning-AI/pytorch-lightning/blob/02311d03fb982560246eead7c08104481fac9579/src/lightning/pytorch/strategies/strategy.py#L602
A method call to a wrapped module gets rerouted through the wrapper's `forward` method instead.
"""

def __call__(self, wrapper_module: nn.Module, original_module: nn.Module, method: callable, *args: Any,
**kwargs: Any):
"""Reroutes a method call through the `wrapper_module`'s `forward` method.
Args:
wrapper_module: The module that has `original_module` wrapped.
original_module: The module that was wrapped inside `wrapper_module`.
method_name: The name of the method that should be called on the `original_module` after inputs get
redirected through the `wrapper_module`'s `forward` method.
*args: The positional arguments to the method `method_name`. They will get passed to a patched
`forward` method instead.
**kwargs: The keyword arguments to the method `method_name`. They will get passed to a patched
`forward` method instead.
"""
original_forward = original_module.forward

def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
            # Unpatch ourselves immediately before calling `method`,
            # because it may itself want to call the real `forward`.
            original_module.forward = original_forward  # type: ignore[method-assign]
            # Call the actual method, e.g. `compute_liger_loss(...)`
out = method(*_args, **_kwargs)
self.on_after_inner_forward(wrapper_module, original_module)
return out

# Patch the original_module's forward so we can redirect the arguments back to the real method
original_module.forward = wrapped_forward # type: ignore[method-assign]

wrapper_output = wrapper_module(*args, **kwargs)
self.on_after_outer_forward(wrapper_module, original_module)
return wrapper_output

def on_after_inner_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None:
pass

def on_after_outer_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None:
pass
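
As a sanity check of the redirection mechanism, a minimal sketch with plain nn.Module objects. The wrapper stands in for a DeepSpeed/DDP-style wrapper whose `forward` must run so its hooks and gather logic are triggered; the class and variable names are illustrative, and the import path follows the export added to swift/trainers/rlhf_trainer/__init__.py above:

import torch
from torch import nn

from swift.trainers.rlhf_trainer.utils import _ForwardRedirection


class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

    def custom_step(self, x):
        # A non-forward method we want routed through the wrapper's forward.
        return self.forward(x).sum()


class Wrapper(nn.Module):
    """Stand-in for an accelerator wrapper: counts how often its forward runs."""

    def __init__(self, inner):
        super().__init__()
        self.inner = inner
        self.forward_calls = 0

    def forward(self, *args, **kwargs):
        self.forward_calls += 1
        return self.inner(*args, **kwargs)


inner = Inner()
wrapper = Wrapper(inner)
redirect = _ForwardRedirection()

x = torch.randn(2, 4)
out = redirect(wrapper, inner, inner.custom_step, x)
assert wrapper.forward_calls == 1  # custom_step was rerouted through Wrapper.forward
print(out)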