diff --git a/requirements/install_all.sh b/requirements/install_all.sh
index 3bcce3ffc8..6b0c9f6987 100644
--- a/requirements/install_all.sh
+++ b/requirements/install_all.sh
@@ -9,4 +9,6 @@ pip install timm -U
 pip install deepspeed -U
 pip install qwen_vl_utils qwen_omni_utils decord librosa pyav icecream soundfile -U
 pip install liger_kernel nvitop pre-commit -U
+pip install wandb
+pip install math_verify==0.5.2
 # flash-attn: https://github.com/Dao-AILab/flash-attention/releases
diff --git a/swift/llm/argument/rlhf_args.py b/swift/llm/argument/rlhf_args.py
index 7dba932c53..552d49f9c3 100644
--- a/swift/llm/argument/rlhf_args.py
+++ b/swift/llm/argument/rlhf_args.py
@@ -217,13 +217,18 @@ def _set_default(self):
     def _check_grpo(self):
         if self.rlhf_type != 'grpo':
             return
         from packaging import version
+
+        import trl
         trl_version = version.parse(trl.__version__)
         assert trl_version >= version.parse('0.17'), ('Your current version of `trl` is outdated. '
                                                       'Please update it by running: pip install -U trl')
+        if self.use_liger_loss:
+            from trl.import_utils import is_liger_kernel_available
+            assert is_liger_kernel_available(), (
+                'Please install/update liger-kernel by running: pip install -U liger-kernel')
+
         if self.num_generations < 2:
             raise ValueError(
                 'GRPO requires at least 2 generations per prompt to calculate the advantages. You provided '
diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py
index f224fcaf4f..366af66d1e 100644
--- a/swift/llm/template/base.py
+++ b/swift/llm/template/base.py
@@ -1147,7 +1147,8 @@ def pre_forward_hook(self, model: nn.Module, args, kwargs):
         old_kwargs = to_device(kwargs, model.device)
         kwargs = to_device(self._post_encode(model, old_kwargs), model.device)
         for k, v in old_kwargs.items():
-            if k in {'input_ids', 'attention_mask', 'labels', 'position_ids'} and k not in kwargs:
+            if k in {'input_ids', 'attention_mask', 'labels', 'position_ids', 'output_hidden_states'
+                     } and k not in kwargs:
                 kwargs[k] = v
         if 'inputs_embeds' in kwargs:
             kwargs.pop('input_ids', None)
diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py
index b5d039dba8..d5e26a5a09 100644
--- a/swift/trainers/arguments.py
+++ b/swift/trainers/arguments.py
@@ -203,6 +203,8 @@ class GRPOArgumentsMixin:
     # dataset
     dataset_shuffle: Optional[bool] = True
 
+    use_liger_loss: bool = False
+
 
 @dataclass
 class TrainingArguments(SwiftArgumentsMixin, HfTrainingArguments):
diff --git a/swift/trainers/rlhf_trainer/__init__.py b/swift/trainers/rlhf_trainer/__init__.py
index eca9ba382d..bdb852eaff 100644
--- a/swift/trainers/rlhf_trainer/__init__.py
+++ b/swift/trainers/rlhf_trainer/__init__.py
@@ -12,7 +12,7 @@
     from .ppo_trainer import PPOTrainer
     from .reward_trainer import RewardTrainer
     from .rlhf_mixin import RLHFTrainerMixin
-    from .utils import patch_lora_merge, patch_lora_unmerge, round_robin
+    from .utils import patch_lora_merge, patch_lora_unmerge, round_robin, _ForwardRedirection
 else:
     _import_structure = {
         'cpo_trainer': ['CPOTrainer'],
@@ -23,7 +23,7 @@
         'ppo_trainer': ['PPOTrainer'],
         'reward_trainer': ['RewardTrainer'],
         'rlhf_mixin': ['RLHFTrainerMixin'],
-        'utils': ['patch_lora_merge', 'patch_lora_unmerge', 'round_robin'],
+        'utils': ['patch_lora_merge', 'patch_lora_unmerge', 'round_robin', '_ForwardRedirection'],
     }
 
 import sys
diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py
index ddcdc9b3fe..7f014eec06 100644
--- a/swift/trainers/rlhf_trainer/grpo_trainer.py
+++ b/swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -36,7 +36,7 @@
 from swift.utils import JsonlWriter, gc_collect, get_device, get_logger, is_vllm_available, is_wandb_available
 from ..mixin import SwiftMixin
 from .rlhf_mixin import RLHFTrainerMixin
-from .utils import patch_lora_merge, patch_lora_unmerge, unwrap_model_for_generation
+from .utils import _ForwardRedirection, patch_lora_merge, patch_lora_unmerge, unwrap_model_for_generation
 from .vllm_client import VLLMClient
 
 del HFGRPOTrainer.__init__
@@ -179,6 +179,25 @@ def __init__(self,
         vllm_client = kwargs.pop('vllm_client')  # for external vllm
         super().__init__(model, ref_model, *_args, **kwargs)
 
+        # Multi-step
+        self.num_iterations = args.num_iterations  # = 𝜇 in the GRPO paper
+        self.epsilon_low = args.epsilon
+        self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
+
+        self.use_liger_loss = self.args.use_liger_loss
+        if self.use_liger_loss:
+            from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss
+
+            self.liger_grpo_loss = LigerFusedLinearGRPOLoss(
+                beta=self.beta,
+                epsilon_low=self.epsilon_low,
+                epsilon_high=self.epsilon_high,
+                temperature=self.temperature,
+                use_ref_model=self.beta != 0.0,
+                loss_type=self.loss_type,
+                max_completion_length=self.max_completion_length,
+            )
+            self._forward_redirection = _ForwardRedirection()
+
         self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)}
         self.log_completions = args.log_completions
@@ -275,11 +294,6 @@ def __init__(self,
                 self.reward_funcs[i] = self.accelerator.prepare_model(
                     reward_func, evaluation_mode=True, device_placement=True)
 
-        # Multi-step
-        self.num_iterations = args.num_iterations  # = 𝜇 in the GRPO paper
-        self.epsilon_low = args.epsilon
-        self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
-
         # Tracks the number of iterations (forward + backward passes), including those within a gradient accumulation cycle.  # noqa
         self._step = 0
         # Buffer the batch to reuse generated outputs across multiple updates. For more details, see
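
Note on the `epsilon_low` / `epsilon_high` pair introduced above: they bound the per-token importance ratio asymmetrically (DAPO-style "clip higher"). A minimal, self-contained sketch of that clipped surrogate, with made-up tensor names and shapes rather than the trainer's own variables:

import torch

# Illustrative shapes only: 2 completions, 4 tokens each.
per_token_logps = torch.randn(2, 4)             # log-probs under the current policy
old_per_token_logps = per_token_logps.detach()  # log-probs recorded at sampling time
advantages = torch.tensor([0.5, -0.5])          # one advantage per completion
epsilon_low, epsilon_high = 0.2, 0.28

ratio = torch.exp(per_token_logps - old_per_token_logps)
clipped_ratio = torch.clamp(ratio, 1 - epsilon_low, 1 + epsilon_high)
per_token_loss = -torch.min(ratio * advantages.unsqueeze(-1), clipped_ratio * advantages.unsqueeze(-1))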
@@ -495,7 +509,7 @@ def _move_model_to_vllm(self):
         if self.args.async_generate:
             # before sync weight, we should wait async generate finish
             self._wait_queue()
-        if self.args.use_vllm:
+        if self.use_vllm:
             llm_model = self.engine.inner_model
         else:
             llm_model = self.engine.engine.engine
@@ -949,15 +963,6 @@ def _prepare_batch_inputs(self, inputs: InputsType, rewards: torch.Tensor) -> Li
             batch_encoded_inputs['old_per_token_logps'] = (
                 self._get_per_token_logps(self.model, batch_encoded_inputs) if self.old_policy else None)
 
-            if self.beta == 0.0:
-                ref_per_token_logps = None
-            elif self.ref_model is not None:
-                ref_per_token_logps = self._get_per_token_logps(self.ref_model, batch_encoded_inputs)
-            else:
-                with self.accelerator.unwrap_model(self.model).disable_adapter():
-                    ref_per_token_logps = self._get_per_token_logps(self.model, batch_encoded_inputs)
-            batch_encoded_inputs['ref_per_token_logps'] = ref_per_token_logps
-
             ga_batch_encoded_inputs.append(batch_encoded_inputs)
         return ga_batch_encoded_inputs
 
@@ -1004,6 +1009,13 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
         if isinstance(inputs, list):
             assert len(inputs) == 1
             inputs = inputs[0]
+        if self.use_liger_loss:
+            unwrapped_model = self.accelerator.unwrap_model(model)
+            return self._forward_redirection(model, unwrapped_model, self.compute_liger_loss, unwrapped_model, inputs)
+        else:
+            return self._compute_loss(model, inputs)
+
+    def _compute_loss(self, model, inputs):
         completion_mask = inputs['completion_mask']
         truncated_mask = inputs['truncated_mask']
         # apply the completion_mask to exclude loss and metrics for overlong completions
@@ -1017,7 +1029,13 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
 
         # Compute the KL divergence between the model and the reference model
         if self.beta != 0.0:
-            ref_per_token_logps = inputs['ref_per_token_logps']
+            with torch.no_grad():
+                if self.ref_model is not None:
+                    ref_per_token_logps = self._get_per_token_logps(self.ref_model, inputs)
+                else:
+                    with self.accelerator.unwrap_model(self.model).disable_adapter():
+                        ref_per_token_logps = self._get_per_token_logps(self.model, inputs)
+
             per_token_kl = (
                 torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1)
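
Note: the `per_token_kl` expression above is the k3 estimator of KL(policy || ref), exp(d) - d - 1 with d = ref_logp - logp, which is non-negative and computable purely from per-token log-probs. A tiny self-contained check, with assumed names rather than the trainer's variables:

import torch

# Fake per-token log-probs under the policy and the reference model (any real values work).
per_token_logps = -torch.rand(2, 4)
ref_per_token_logps = -torch.rand(2, 4)

delta = ref_per_token_logps - per_token_logps
per_token_kl = torch.exp(delta) - delta - 1  # k3 estimator: exp(x) - x - 1 >= 0, zero iff x == 0
assert (per_token_kl >= 0).all()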
@@ -1096,6 +1114,72 @@ def _get_per_token_logps(self, model, inputs):
             input_ids = input_ids[:, -logits_to_keep:]
         return selective_log_softmax(logits, input_ids)  # compute logprobs for the input tokens
 
+    @profiling_decorator
+    def _get_last_hidden_state(self, unwrapped_model, inputs, logits_to_keep):
+        # unwrap the model to access the model.model
+        if is_peft_model(unwrapped_model):
+            unwrapped_model = unwrapped_model.base_model.model
+        if not unwrapped_model.model_meta.is_multimodal:
+            last_hidden_state = unwrapped_model.model(
+                input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask']).last_hidden_state
+        else:
+            inputs = {
+                k: v
+                for k, v in inputs.items() if k not in [
+                    'logits_to_keep', 'completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps',
+                    'truncated_mask'
+                ]
+            }
+            with self._template_context(self.template):
+                outputs = unwrapped_model(**inputs, output_hidden_states=True)
+            last_hidden_state = outputs.hidden_states[-1]
+
+        last_hidden_state = last_hidden_state[:, :-1, :]  # (B, L-1, H)
+        if logits_to_keep is not None:
+            last_hidden_state = last_hidden_state[:, -logits_to_keep:, :]  # (B, logits_to_keep, H)
+        return last_hidden_state
+
+    def compute_liger_loss(self, unwrapped_model, inputs):
+        # Compute the per-token log probabilities for the model
+        input_ids = inputs['input_ids']
+        logits_to_keep = inputs['logits_to_keep']
+        completion_ids = input_ids[:, -logits_to_keep:]
+        completion_mask = inputs['completion_mask']
+
+        # Compute the KL divergence between the model and the reference model
+        ref_per_token_logps = None
+        if self.beta != 0.0:
+            with torch.no_grad():
+                if self.ref_model is not None:
+                    ref_per_token_logps = self._get_per_token_logps(self.ref_model, inputs)
+                else:
+                    with self.accelerator.unwrap_model(self.model).disable_adapter():
+                        ref_per_token_logps = self._get_per_token_logps(self.model, inputs)
+
+        # get the last hidden state of the model
+        last_hidden_state = self._get_last_hidden_state(unwrapped_model, inputs, logits_to_keep)
+        # compute loss and metrics using liger grpo loss
+        loss, metrics = self.liger_grpo_loss(
+            _input=last_hidden_state,
+            lin_weight=unwrapped_model.lm_head.weight,
+            selected_token_ids=completion_ids,
+            attention_mask=completion_mask,
+            advantages=inputs['advantages'],
+            bias=unwrapped_model.lm_head.bias,
+            old_per_token_logps=inputs['old_per_token_logps'],
+            ref_per_token_logps=ref_per_token_logps,
+        )
+        # Extract metrics from the liger_grpo_loss output
+        # KL divergence is the first metric when beta is non-zero
+        mean_kl = metrics[0] if self.beta != 0.0 else None
+        clip_ratio = metrics[-1]
+
+        mode = 'eval' if self.control.should_evaluate else 'train'
+        if self.beta != 0.0:
+            self._metrics[mode]['kl'].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
+        self._metrics[mode]['clip_ratio'].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
+        return loss
+
     def evaluation_loop(self, dataloader, *args, **kwargs):
         # Wait for the training rollout to complete
         if self.args.async_generate:
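
Note on the Liger path above: `LigerFusedLinearGRPOLoss` takes the pre-`lm_head` hidden states plus the `lm_head` weight/bias and fuses the projection, log-softmax and GRPO loss in chunks, so the full `(B, T, vocab)` logits tensor never has to be materialized. A rough eager-mode sketch of the equivalent computation, with illustrative shapes only; this is not the kernel implementation, and the clipping/KL terms of the real loss are omitted:

import torch

# Illustrative shapes: B=2 completions, T=4 completion tokens, H=8 hidden, V=16 vocab.
last_hidden_state = torch.randn(2, 4, 8)       # `_input` passed to liger_grpo_loss
lm_head_weight = torch.randn(16, 8)            # `lin_weight`
completion_ids = torch.randint(0, 16, (2, 4))  # `selected_token_ids`
advantages = torch.tensor([1.0, -1.0])

# What the fused kernel computes, written out eagerly (the kernel works chunk by chunk
# over the sequence so the logits below never exist all at once).
logits = last_hidden_state @ lm_head_weight.T
per_token_logps = torch.log_softmax(logits, dim=-1).gather(-1, completion_ids.unsqueeze(-1)).squeeze(-1)
per_token_loss = -per_token_logps * advantages.unsqueeze(-1)  # plus clipping/KL terms in the real loss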
+ """ + original_forward = original_module.forward + + def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any: + # Unpatch ourselves immediately before calling the method `method_name` + # because itself may want to call the real `forward` + original_module.forward = original_forward # type: ignore[method-assign] + # Call the actual method e.g. `.training_step(...)` + out = method(*_args, **_kwargs) + self.on_after_inner_forward(wrapper_module, original_module) + return out + + # Patch the original_module's forward so we can redirect the arguments back to the real method + original_module.forward = wrapped_forward # type: ignore[method-assign] + + wrapper_output = wrapper_module(*args, **kwargs) + self.on_after_outer_forward(wrapper_module, original_module) + return wrapper_output + + def on_after_inner_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None: + pass + + def on_after_outer_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None: + pass