diff --git a/paddlenlp/rl/models/ppo_model_utils.py b/paddlenlp/rl/models/ppo_model_utils.py
index 1cb3dd7afc1f..67318c451f91 100644
--- a/paddlenlp/rl/models/ppo_model_utils.py
+++ b/paddlenlp/rl/models/ppo_model_utils.py
@@ -485,16 +485,14 @@ def forward(
             kl_loss_coeff=self.kl_loss_coeff,
             loop_chunk_size=1024,
             response_start=response_start,
-            use_actor_fused_loss=self.entropy_coeff <= 0,  # currently only support kunbo's fused head loss
+            use_actor_fused_loss=True,  # currently only support kunbo's fused head loss
             temperature=self.temperature,
         )
         with paddle.no_grad():
             self.info_buffer["kl_loss"] = (
                 kl_loss.detach() / self.kl_loss_coeff if self.kl_loss_coeff > 0 else paddle.to_tensor([0.0])
             )
-            self.info_buffer["entropy_loss"] = (
-                entropy_loss.detach() / self.entropy_coeff if self.entropy_coeff > 0 else paddle.to_tensor([0.0])
-            )
+            self.info_buffer["entropy_loss"] = entropy_loss.detach()
             self.info_buffer["pure_policy_loss"] = (
                 pg_loss.detach() / self.pg_loss_coeff if self.pg_loss_coeff > 0 else paddle.to_tensor([0.0])
             )
@@ -716,6 +714,7 @@ def forward(
         clip_range_score: float,
         kl_loss_coeff: float,  # KL loss coefficient
         temperature: float,
+        print_entropy_loss: bool = True,
     ):
         """
         forward function of ActorFusedLoss
@@ -813,11 +812,11 @@ def forward(
             token_end_idx = min(i + loop_chunk_size, n_tokens)
             hidden_states_chunk = hidden_states[token_start_idx:token_end_idx]
             labels_chunk = labels[token_start_idx:token_end_idx]
-            old_log_probs_chunk = old_log_probs[token_start_idx:token_end_idx]
+            mask_chunk = loss_mask[token_start_idx:token_end_idx]
+            old_log_probs_chunk = old_log_probs[token_start_idx:token_end_idx] * mask_chunk
             if kl_loss_coeff > 0:
-                ref_log_chunk = ref_log_probs[token_start_idx:token_end_idx]
+                ref_log_chunk = ref_log_probs[token_start_idx:token_end_idx] * mask_chunk
             advantages_chunk = advantages[token_start_idx:token_end_idx]
-            mask_chunk = loss_mask[token_start_idx:token_end_idx]
 
             # Calculate the current logits_chunk, not fused linear
             logits_chunk_cast = paddle.matmul(hidden_states_chunk, lm_head_weight_cast, transpose_y=transpose_y)
@@ -841,13 +840,14 @@ def forward(
                 token_loss_chunk = F.cross_entropy(logits_chunk, labels_chunk, reduction="none")
             softmax_output_chunk = F.softmax(logits_chunk, axis=-1)
-            log_probs_chunk = -token_loss_chunk.squeeze(axis=-1)
+            log_probs_chunk = -token_loss_chunk.squeeze(axis=-1) * mask_chunk
 
             # calculate gradient, note sign
             grad_logits_chunk = labels_one_hot.astype("float32") - softmax_output_chunk
             grad_logits_chunk = grad_logits_chunk.astype(dtype)
 
             # ratio
             ratio_chunk = paddle.exp(log_probs_chunk - old_log_probs_chunk)
+
             clipped_ratio_chunk = paddle.clip(
                 ratio_chunk, min=1.0 - clip_range_ratio_low, max=1.0 + clip_range_ratio_high
             )
@@ -892,6 +892,7 @@ def forward(
             if kl_loss_coeff > 0:
                 # [3] kl loss
                 delta_chunk = ref_log_chunk - log_probs_chunk
+
                 exp_delta_chunk = paddle.exp(delta_chunk)
                 kl_loss_estimate_chunk = exp_delta_chunk - delta_chunk - 1
                 kl_loss_clipped_chunk = (
@@ -912,6 +913,17 @@ def forward(
                 )
                 d_loss_d_logits_chunk += d_kl_log_probs_chunk.unsqueeze(-1) * d_log_probs_d_logits_chunk
 
+            if print_entropy_loss:
+                # [2] entropy loss
+                log_prob_chunk = paddle.log(paddle.clip(softmax_output_chunk, min=1e-12))
+                entropy_loss_chunk = -(softmax_output_chunk * log_prob_chunk).sum(axis=-1) * mask_chunk
+                # entropy_loss_chunk shape is [bs, seqlen, vocab_size // tensor_parallel_degree], do all_reduce sum here
+                if tensor_parallel_degree > 1 and tensor_parallel_output:
+                    paddle.distributed.all_reduce(
+                        entropy_loss_chunk, op=paddle.distributed.ReduceOp.SUM, group=model_parallel_group
+                    )
+                total_entropy_loss += entropy_loss_chunk.sum() / divisor
+
             # grads
             if grad_hidden_states is not None:
                 grad_hidden_states[token_start_idx:token_end_idx] = paddle.matmul(
diff --git a/paddlenlp/rl/trainer/actor_trainer.py b/paddlenlp/rl/trainer/actor_trainer.py
index 5e56359ec3ee..d160c6a2ea9b 100644
--- a/paddlenlp/rl/trainer/actor_trainer.py
+++ b/paddlenlp/rl/trainer/actor_trainer.py
@@ -16,6 +16,9 @@
 import numpy as np
 import paddle
+import paddle.nn.functional as F
+from paddle.distributed import fleet
+from paddle.distributed.fleet.layers.mpu import mp_ops
 from paddle.distributed.fleet.meta_parallel import ParallelCrossEntropy
 
 from ..models.ppo_model_utils import (
@@ -57,6 +60,13 @@ def compute_logprob(self, input_ids: paddle.Tensor, position_ids: paddle.Tensor
         Raises:
             None.
         """
+        if self.args.use_fused_head_and_loss_fn:
+            return self.compute_fused_logprob(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                **kwargs,
+            )
+
         log_probs_list = []
         batch_size, sequence_length = input_ids.shape
         per_device_logprob_batch_size = self.args.per_device_logprob_batch_size
         num_batches = (batch_size + per_device_logprob_batch_size - 1) // per_device_logprob_batch_size
@@ -147,6 +157,153 @@ def compute_logprob(self, input_ids: paddle.Tensor, position_ids: paddle.Tensor
 
         return paddle.concat(log_probs_list, axis=0)
 
+    def compute_fused_logprob(
+        self, input_ids: paddle.Tensor, position_ids: paddle.Tensor = None, loop_chunk_size=1024, **kwargs
+    ):
+        log_probs_list = []
+        batch_size, sequence_length = input_ids.shape
+        per_device_logprob_batch_size = self.args.per_device_logprob_batch_size
+        num_batches = (batch_size + per_device_logprob_batch_size - 1) // per_device_logprob_batch_size
+
+        # Pipe model outputs a logits tensor with LMHead, while non-pipe model
+        # outputs a tuple with logits tensor as the only one element.
+        startend_row_indices = create_startend_row_indices(input_ids, self.tokenizer.pad_token_id)
+        response_start = kwargs["prompt"].shape[-1] - 1 if "prompt" in kwargs else 0
+
+        num_embeddings = self.model.config.vocab_size
+        tensor_parallel_degree = self.model.config.tensor_parallel_degree
+        tensor_parallel_output = self.model.config.tensor_parallel_output
+
+        for i in range(num_batches):
+            # Calculate the start and end indices for the current batch
+            start_index = i * per_device_logprob_batch_size
+            end_index = min(start_index + per_device_logprob_batch_size, batch_size)
+
+            # Extract the current batch
+            current_input_ids = input_ids[start_index:end_index]
+            current_startend_row_indices = (
+                startend_row_indices[start_index:end_index] if startend_row_indices is not None else None
+            )
+            current_position_ids = position_ids[start_index:end_index] if position_ids is not None else None
+            current_labels = current_input_ids[:, response_start + 1 :]
+
+            if self.args.use_remove_padding:
+                from ..utils.bert_padding import prepare_flashmask_inputs
+
+                update_inputs = prepare_flashmask_inputs(
+                    current_input_ids,
+                    current_position_ids,
+                    self.tokenizer.pad_token_id,
+                    self.model.config.sequence_parallel,
+                    self.model.config.tensor_parallel_degree,
+                )
+                current_input_ids = update_inputs["input_ids"]
+                current_position_ids = update_inputs["position_ids"]
+                current_startend_row_indices = update_inputs["attn_mask_startend_row_indices"]
+                indices = update_inputs["indices"]
+                raw_input_shape = update_inputs["raw_input_shape"]
+                pad_size = update_inputs["pad_size"]
+
+            # NOTE: for use_fused_head_and_loss_fn
+            self.model.training = True
+            hidden_states, lm_head_weight, lm_head_bias, transpose_y = self.model(
+                current_input_ids,
+                position_ids=current_position_ids,
+                attn_mask_startend_row_indices=current_startend_row_indices,
+            )
+            self.model.training = False
+
+            if self.args.use_remove_padding:
+                if pad_size > 0:
+                    hidden_states = hidden_states[:, :-pad_size]
+
+                from ..utils.bert_padding import pad_input
+
+                hidden_states = pad_input(
+                    hidden_states.squeeze(0), indices, batch=raw_input_shape[0], seqlen=raw_input_shape[1]
+                ).contiguous()
+
+            if self.args.use_fp32_compute and hidden_states.dtype != paddle.float32:
+                hidden_states = hidden_states.cast(paddle.float32)
+                lm_head_weight = lm_head_weight.cast(paddle.float32)
+                if lm_head_bias is not None:
+                    lm_head_bias = lm_head_bias.cast(paddle.float32)
+
+            # Recover
+            hidden_states = hidden_states[:, response_start:-1, :]
+            dtype = hidden_states.dtype
+            original_shape = hidden_states.shape
+            if tensor_parallel_degree > 1:
+                assert tensor_parallel_output, (
+                    "When tensor_parallel_degree > 1 and use_fused_head_and_loss_fn, "
+                    "tensor_parallel_output needs to be set to True."
+                )
+            # Parallel Configuration
+            if tensor_parallel_degree > 1 and tensor_parallel_output:
+                hcg = fleet.get_hybrid_communicate_group()
+                model_parallel_group = hcg.get_model_parallel_group()
+                tensor_parallel_degree = hcg.get_model_parallel_world_size()
+
+            # reshape
+            hidden_states = hidden_states.reshape([-1, original_shape[-1]])
+            labels = current_labels.reshape([-1])
+
+            n_tokens = hidden_states.shape[0]
+            n_classes = lm_head_weight.shape[0] if transpose_y else lm_head_weight.shape[1]
+
+            # convert dtype of weights and biases of lm_head
+            lm_head_weight_cast = lm_head_weight.astype(dtype)
+            if lm_head_bias is not None:
+                lm_head_bias_cast = lm_head_bias.astype(dtype)
+
+            # use indices to distinguish the devices.
+            if tensor_parallel_degree > 1 and tensor_parallel_output:
+                rank = hcg.get_model_parallel_rank()
+                per_part_size = num_embeddings // tensor_parallel_degree
+                indices = paddle.arange(
+                    rank * per_part_size,
+                    rank * per_part_size + n_classes,
+                    dtype=labels.dtype,
+                ).unsqueeze(0)
+            else:
+                indices = paddle.arange(num_embeddings, dtype=labels.dtype).unsqueeze(0)
+
+            log_prob_chunks = []
+            for ci in range(0, n_tokens, loop_chunk_size):
+                token_start_idx = ci
+                token_end_idx = min(ci + loop_chunk_size, n_tokens)
+                hidden_states_chunk = hidden_states[token_start_idx:token_end_idx]
+                labels_chunk = labels[token_start_idx:token_end_idx]
+
+                # Calculate the current logits_chunk, not fused linear
+                logits_chunk_cast = paddle.matmul(hidden_states_chunk, lm_head_weight_cast, transpose_y=transpose_y)
+                if lm_head_bias is not None:
+                    logits_chunk_cast += lm_head_bias_cast
+
+                logits_chunk = logits_chunk_cast.astype("float32")
+                logits_chunk = logits_chunk / self.args.temperature
+
+                # rewritten as cross entropy
+                if tensor_parallel_degree > 1 and tensor_parallel_output:
+                    token_loss_chunk = mp_ops._c_softmax_with_cross_entropy(
+                        logits_chunk,
+                        labels_chunk,
+                        group=model_parallel_group,
+                        return_softmax=False,
+                    )
+                else:
+                    token_loss_chunk = F.cross_entropy(logits_chunk, labels_chunk, reduction="none")
+                log_prob_chunk = -token_loss_chunk.squeeze(axis=-1)
+                log_prob_chunks.append(log_prob_chunk)
+
+            log_probs = paddle.concat(log_prob_chunks, axis=-1).reshape(original_shape[:-1])
+            log_probs_list.append(log_probs)
+
+            log_prob_chunks = None
+            paddle.device.cuda.empty_cache()
+
+        return paddle.concat(log_probs_list, axis=0)
+
     def update_actor(self, rl_batch: Dict[str, paddle.Tensor]) -> Dict[str, Any]:
         # inputs shared by policy and value trainer
         input_ids = rl_batch["input_ids"].contiguous()  # length: src+tgt
diff --git a/paddlenlp/rl/trainer/ppo_trainer.py b/paddlenlp/rl/trainer/ppo_trainer.py
index 0cadacbf175f..f0228b3813d3 100644
--- a/paddlenlp/rl/trainer/ppo_trainer.py
+++ b/paddlenlp/rl/trainer/ppo_trainer.py
@@ -1473,15 +1473,16 @@ def train(
                 if self.args.balance_batch:
                     batch = self._balance_batch(batch)
 
-                # step 2-3: compute logprob for rollout data
-                with TimerScope(self.timers, RolloutStages.ROLLOUT_LOGPROB):
-                    with reload_and_offload_scope(self, self.reference_model):
-                        with TimerScope(self.timers, RolloutStages.ROLLOUT_REF_LOGPROB):
-                            batch["ref_log_probs"] = self.reference_trainer.compute_logprob(**batch)
-
-                    with reload_and_offload_scope(self, self.actor_model):
-                        with TimerScope(self.timers, RolloutStages.ROLLOUT_OLD_LOGPROB):
-                            batch["log_probs"] = self.actor_trainer.compute_logprob(**batch)
+                with self.autocast_smart_context_manager():
+                    # step 2-3: compute logprob for rollout data
+                    with TimerScope(self.timers, RolloutStages.ROLLOUT_LOGPROB):
+                        with reload_and_offload_scope(self, self.reference_model):
+                            with TimerScope(self.timers, RolloutStages.ROLLOUT_REF_LOGPROB):
+                                batch["ref_log_probs"] = self.reference_trainer.compute_logprob(**batch)
+
+                        with reload_and_offload_scope(self, self.actor_model):
+                            with TimerScope(self.timers, RolloutStages.ROLLOUT_OLD_LOGPROB):
+                                batch["log_probs"] = self.actor_trainer.compute_logprob(**batch)
 
                 # step 2-2: compute reward for rollout data
                 with TimerScope(
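For orientation, below is a minimal, standalone sketch of the chunked log-prob computation that `compute_fused_logprob` performs in the patch above. This is not the PR's code: it assumes a single device (no tensor parallelism, no remove-padding, no LM-head bias or `transpose_y` handling), and the function name, tensor names, and toy shapes are illustrative only. The point is that materializing logits one chunk at a time bounds peak memory at roughly `loop_chunk_size * vocab_size` instead of `n_tokens * vocab_size`.

```python
# Minimal sketch (assumed shapes, single device): per-token log-probs from
# hidden states and an LM-head weight, computed chunk by chunk so the full
# [n_tokens, vocab_size] logits tensor is never materialized at once.
import paddle
import paddle.nn.functional as F


def chunked_logprob(hidden_states, lm_head_weight, labels, loop_chunk_size=1024, temperature=1.0):
    # hidden_states: [n_tokens, hidden_size], lm_head_weight: [hidden_size, vocab_size], labels: [n_tokens]
    n_tokens = hidden_states.shape[0]
    log_prob_chunks = []
    for start in range(0, n_tokens, loop_chunk_size):
        end = min(start + loop_chunk_size, n_tokens)
        # only this chunk's logits exist in memory at any time
        logits = paddle.matmul(hidden_states[start:end], lm_head_weight).astype("float32") / temperature
        # cross_entropy with reduction="none" returns -log p(label); negate to recover the log-prob
        token_loss = F.cross_entropy(logits, labels[start:end], reduction="none")
        log_prob_chunks.append(-token_loss.reshape([-1]))
    return paddle.concat(log_prob_chunks, axis=0)


# toy usage with random data
h = paddle.randn([10, 16])
w = paddle.randn([16, 32])
y = paddle.randint(0, 32, shape=[10])
print(chunked_logprob(h, w, y, loop_chunk_size=4).shape)  # [10]
```

In the patch itself, the tensor-parallel branch swaps `F.cross_entropy` for `mp_ops._c_softmax_with_cross_entropy` over the model-parallel group, but the chunking structure is the same.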