[RL] logprob compute use the same method #10596

Merged
28 changes: 20 additions & 8 deletions paddlenlp/rl/models/ppo_model_utils.py
@@ -485,16 +485,14 @@
kl_loss_coeff=self.kl_loss_coeff,
loop_chunk_size=1024,
response_start=response_start,
use_actor_fused_loss=self.entropy_coeff <= 0, # currently only support kunbo's fused head loss
use_actor_fused_loss=True, # currently only support kunbo's fused head loss
temperature=self.temperature,
)
with paddle.no_grad():
self.info_buffer["kl_loss"] = (
kl_loss.detach() / self.kl_loss_coeff if self.kl_loss_coeff > 0 else paddle.to_tensor([0.0])
)
self.info_buffer["entropy_loss"] = (
entropy_loss.detach() / self.entropy_coeff if self.entropy_coeff > 0 else paddle.to_tensor([0.0])
)
self.info_buffer["entropy_loss"] = entropy_loss.detach()

self.info_buffer["pure_policy_loss"] = (
pg_loss.detach() / self.pg_loss_coeff if self.pg_loss_coeff > 0 else paddle.to_tensor([0.0])
)
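Side note on the hunk above: the KL and pure policy terms are divided back by their coefficients before being written to `info_buffer`, while `entropy_loss` is now logged raw, because the fused head loss reports it even when `entropy_coeff <= 0` (which is also why `use_actor_fused_loss` is now hard-coded to `True`). A minimal sketch of this logging convention, using the hypothetical helper name `log_ppo_loss_components` with the trainer state passed in explicitly:

```python
import paddle

def log_ppo_loss_components(info_buffer, pg_loss, kl_loss, entropy_loss,
                            pg_loss_coeff, kl_loss_coeff):
    """Record the individual PPO loss terms for logging (sketch, not the trainer API)."""
    with paddle.no_grad():
        # KL and policy terms are de-scaled so the logged values stay comparable
        # across different coefficient settings.
        info_buffer["kl_loss"] = (
            kl_loss.detach() / kl_loss_coeff if kl_loss_coeff > 0 else paddle.to_tensor([0.0])
        )
        info_buffer["pure_policy_loss"] = (
            pg_loss.detach() / pg_loss_coeff if pg_loss_coeff > 0 else paddle.to_tensor([0.0])
        )
        # Entropy is logged as-is; it is no longer rescaled by (or gated on) entropy_coeff.
        info_buffer["entropy_loss"] = entropy_loss.detach()
    return info_buffer
```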
@@ -716,6 +714,7 @@
clip_range_score: float,
kl_loss_coeff: float, # KL loss coefficient
temperature: float,
print_entropy_loss: bool = True,
):
"""
forward function of ActorFusedLoss
@@ -813,11 +812,11 @@
token_end_idx = min(i + loop_chunk_size, n_tokens)
hidden_states_chunk = hidden_states[token_start_idx:token_end_idx]
labels_chunk = labels[token_start_idx:token_end_idx]
old_log_probs_chunk = old_log_probs[token_start_idx:token_end_idx]
mask_chunk = loss_mask[token_start_idx:token_end_idx]
old_log_probs_chunk = old_log_probs[token_start_idx:token_end_idx] * mask_chunk

if kl_loss_coeff > 0:
ref_log_chunk = ref_log_probs[token_start_idx:token_end_idx]
ref_log_chunk = ref_log_probs[token_start_idx:token_end_idx] * mask_chunk

advantages_chunk = advantages[token_start_idx:token_end_idx]
mask_chunk = loss_mask[token_start_idx:token_end_idx]

# Calculate the current logits_chunk, not fused linear
logits_chunk_cast = paddle.matmul(hidden_states_chunk, lm_head_weight_cast, transpose_y=transpose_y)
@@ -841,13 +840,14 @@
token_loss_chunk = F.cross_entropy(logits_chunk, labels_chunk, reduction="none")
softmax_output_chunk = F.softmax(logits_chunk, axis=-1)

log_probs_chunk = -token_loss_chunk.squeeze(axis=-1)
log_probs_chunk = -token_loss_chunk.squeeze(axis=-1) * mask_chunk

# calculate gradient, note sign
grad_logits_chunk = labels_one_hot.astype("float32") - softmax_output_chunk
grad_logits_chunk = grad_logits_chunk.astype(dtype)

# ratio
ratio_chunk = paddle.exp(log_probs_chunk - old_log_probs_chunk)

clipped_ratio_chunk = paddle.clip(
ratio_chunk, min=1.0 - clip_range_ratio_low, max=1.0 + clip_range_ratio_high
)
@@ -892,6 +892,7 @@
if kl_loss_coeff > 0:
# [3] kl loss
delta_chunk = ref_log_chunk - log_probs_chunk

exp_delta_chunk = paddle.exp(delta_chunk)
kl_loss_estimate_chunk = exp_delta_chunk - delta_chunk - 1
kl_loss_clipped_chunk = (
@@ -912,6 +913,17 @@
)
d_loss_d_logits_chunk += d_kl_log_probs_chunk.unsqueeze(-1) * d_log_probs_d_logits_chunk

if print_entropy_loss:

# [2] entropy loss
log_prob_chunk = paddle.log(paddle.clip(softmax_output_chunk, min=1e-12))
entropy_loss_chunk = -(softmax_output_chunk * log_prob_chunk).sum(axis=-1) * mask_chunk

# entropy_loss_chunk shape is [bs, seqlen, vocab_size // tensor_parallel_degree], do all_reduce sum here
if tensor_parallel_degree > 1 and tensor_parallel_output:
paddle.distributed.all_reduce(

entropy_loss_chunk, op=paddle.distributed.ReduceOp.SUM, group=model_parallel_group
)
total_entropy_loss += entropy_loss_chunk.sum() / divisor


# grads
if grad_hidden_states is not None:
grad_hidden_states[token_start_idx:token_end_idx] = paddle.matmul(
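For readers skimming the hunks above: the `* mask_chunk` additions zero out the current and old log probabilities on padded positions before the importance ratio is formed, so masked tokens produce a ratio of exactly 1 and drop out once the per-token loss is multiplied by the mask. A minimal standalone sketch of that idea (hypothetical helper name, illustrative clip values; not the fused kernel itself):

```python
import paddle

def masked_ratio(log_probs, old_log_probs, loss_mask,
                 clip_range_ratio_low=0.2, clip_range_ratio_high=0.2):
    """Masked PPO importance ratio: masked tokens contribute a ratio of 1.0."""
    log_probs = log_probs * loss_mask          # zero out prompt / padding positions
    old_log_probs = old_log_probs * loss_mask  # keep both sides consistent
    ratio = paddle.exp(log_probs - old_log_probs)  # exp(0) == 1 on masked tokens
    clipped_ratio = paddle.clip(
        ratio, min=1.0 - clip_range_ratio_low, max=1.0 + clip_range_ratio_high
    )
    return ratio, clipped_ratio
```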
157 changes: 157 additions & 0 deletions paddlenlp/rl/trainer/actor_trainer.py
@@ -16,6 +16,9 @@

import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.distributed import fleet
from paddle.distributed.fleet.layers.mpu import mp_ops

from paddle.distributed.fleet.meta_parallel import ParallelCrossEntropy

from ..models.ppo_model_utils import (
@@ -57,6 +60,13 @@
Raises:
None.
"""
if self.args.use_fused_head_and_loss_fn:
return self.compute_fused_logprob(

input_ids=input_ids,
position_ids=position_ids,
**kwargs,
)

log_probs_list = []
batch_size, sequence_length = input_ids.shape
per_device_logprob_batch_size = self.args.per_device_logprob_batch_size
@@ -147,6 +157,153 @@

return paddle.concat(log_probs_list, axis=0)

def compute_fused_logprob(

self, input_ids: paddle.Tensor, position_ids: paddle.Tensor = None, loop_chunk_size=1024, **kwargs
):
log_probs_list = []
batch_size, sequence_length = input_ids.shape
per_device_logprob_batch_size = self.args.per_device_logprob_batch_size
num_batches = (batch_size + per_device_logprob_batch_size - 1) // per_device_logprob_batch_size


# Pipe model outputs a logits tensor with LMHead, while non-pipe model
# outputs a tuple with logits tensor as the only one element.
startend_row_indices = create_startend_row_indices(input_ids, self.tokenizer.pad_token_id)
response_start = kwargs["prompt"].shape[-1] - 1 if "prompt" in kwargs else 0


num_embeddings = self.model.config.vocab_size
tensor_parallel_degree = self.model.config.tensor_parallel_degree
tensor_parallel_output = self.model.config.tensor_parallel_output


for i in range(num_batches):

# Calculate the start and end indices for the current batch
start_index = i * per_device_logprob_batch_size
end_index = min(start_index + per_device_logprob_batch_size, batch_size)


# Extract the current batch
current_input_ids = input_ids[start_index:end_index]
current_startend_row_indices = (

startend_row_indices[start_index:end_index] if startend_row_indices is not None else None
)
current_position_ids = position_ids[start_index:end_index] if position_ids is not None else None
current_labels = current_input_ids[:, response_start + 1 :]


if self.args.use_remove_padding:
from ..utils.bert_padding import prepare_flashmask_inputs


update_inputs = prepare_flashmask_inputs(

current_input_ids,
current_position_ids,
self.tokenizer.pad_token_id,
self.model.config.sequence_parallel,
self.model.config.tensor_parallel_degree,
)
current_input_ids = update_inputs["input_ids"]
current_position_ids = update_inputs["position_ids"]
current_startend_row_indices = update_inputs["attn_mask_startend_row_indices"]
indices = update_inputs["indices"]
raw_input_shape = update_inputs["raw_input_shape"]
pad_size = update_inputs["pad_size"]


# NOTE: for use_fused_head_and_loss_fn
self.model.training = True
hidden_states, lm_head_weight, lm_head_bias, transpose_y = self.model(

current_input_ids,
position_ids=current_position_ids,
attn_mask_startend_row_indices=current_startend_row_indices,
)
self.model.training = False


if self.args.use_remove_padding:
if pad_size > 0:
hidden_states = hidden_states[:, :-pad_size]


from ..utils.bert_padding import pad_input


hidden_states = pad_input(

hidden_states.squeeze(0), indices, batch=raw_input_shape[0], seqlen=raw_input_shape[1]
).contiguous()

if self.args.use_fp32_compute and hidden_states.dtype != paddle.float32:
hidden_states = hidden_states.cast(paddle.float32)
lm_head_weight = lm_head_weight.cast(paddle.float32)
if lm_head_bias is not None:
lm_head_bias = lm_head_bias.cast(paddle.float32)


# Recover
hidden_states = hidden_states[:, response_start:-1, :]
dtype = hidden_states.dtype
original_shape = hidden_states.shape
if tensor_parallel_degree > 1:
assert tensor_parallel_output, (

"When tensor_parallel_degree > 1 and use_fused_head_and_loss_fn, "
"tensor_parallel_output needs to be set to True."
)
# Parallel Configuration
if tensor_parallel_degree > 1 and tensor_parallel_output:
hcg = fleet.get_hybrid_communicate_group()
model_parallel_group = hcg.get_model_parallel_group()
tensor_parallel_degree = hcg.get_model_parallel_world_size()


# reshape
hidden_states = hidden_states.reshape([-1, original_shape[-1]])
labels = current_labels.reshape([-1])


n_tokens = hidden_states.shape[0]
n_classes = lm_head_weight.shape[0] if transpose_y else lm_head_weight.shape[1]


# convert dtype of weights and biases of lm_head
lm_head_weight_cast = lm_head_weight.astype(dtype)
if lm_head_bias is not None:
lm_head_bias_cast = lm_head_bias.astype(dtype)


# use indices to distinguish the devices.
if tensor_parallel_degree > 1 and tensor_parallel_output:
rank = hcg.get_model_parallel_rank()
per_part_size = num_embeddings // tensor_parallel_degree
indices = paddle.arange(

rank * per_part_size,
rank * per_part_size + n_classes,
dtype=labels.dtype,
).unsqueeze(0)
else:
indices = paddle.arange(num_embeddings, dtype=labels.dtype).unsqueeze(0)


log_prob_chunks = []
for ci in range(0, n_tokens, loop_chunk_size):
token_start_idx = ci
token_end_idx = min(ci + loop_chunk_size, n_tokens)
hidden_states_chunk = hidden_states[token_start_idx:token_end_idx]
labels_chunk = labels[token_start_idx:token_end_idx]


# Calculate the current logits_chunk, not fused linear
logits_chunk_cast = paddle.matmul(hidden_states_chunk, lm_head_weight_cast, transpose_y=transpose_y)
if lm_head_bias is not None:
logits_chunk_cast += lm_head_bias_cast


logits_chunk = logits_chunk_cast.astype("float32")
logits_chunk = logits_chunk / self.args.temperature


# rewritten as cross entropy
if tensor_parallel_degree > 1 and tensor_parallel_output:
token_loss_chunk = mp_ops._c_softmax_with_cross_entropy(

logits_chunk,
labels_chunk,
group=model_parallel_group,
return_softmax=False,
)
else:
token_loss_chunk = F.cross_entropy(logits_chunk, labels_chunk, reduction="none")
log_prob_chunk = -token_loss_chunk.squeeze(axis=-1)
log_prob_chunks.append(log_prob_chunk)


log_probs = paddle.concat(log_prob_chunks, axis=-1).reshape(original_shape[:-1])
log_probs_list.append(log_probs)


log_prob_chunks = None
paddle.device.cuda.empty_cache()


return paddle.concat(log_probs_list, axis=0)


def update_actor(self, rl_batch: Dict[str, paddle.Tensor]) -> Dict[str, Any]:
# inputs shared by policy and value trainer
input_ids = rl_batch["input_ids"].contiguous() # length: src+tgt
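The new `compute_fused_logprob` above mirrors the fused training loss: instead of materializing the full `[n_tokens, vocab_size]` logits tensor, it multiplies the hidden states by the LM head weight one chunk of tokens at a time and reads the per-token log probability off the cross entropy. A condensed, single-device sketch of that loop (no tensor parallelism, hypothetical helper name, bias omitted):

```python
import paddle
import paddle.nn.functional as F

def chunked_logprob(hidden_states, lm_head_weight, labels,
                    temperature=1.0, loop_chunk_size=1024, transpose_y=False):
    """Per-token log-probs without building the full logits tensor at once."""
    n_tokens = hidden_states.shape[0]
    log_prob_chunks = []
    for start in range(0, n_tokens, loop_chunk_size):
        end = min(start + loop_chunk_size, n_tokens)
        # Logits for this chunk only: [chunk_size, vocab_size]
        logits = paddle.matmul(
            hidden_states[start:end], lm_head_weight, transpose_y=transpose_y
        ).astype("float32")
        logits = logits / temperature
        # Cross entropy of the observed next token == negative log-probability
        nll = F.cross_entropy(logits, labels[start:end], reduction="none")
        log_prob_chunks.append(-nll.squeeze(axis=-1))
    return paddle.concat(log_prob_chunks, axis=-1)
```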
19 changes: 10 additions & 9 deletions paddlenlp/rl/trainer/ppo_trainer.py
@@ -1473,15 +1473,16 @@
if self.args.balance_batch:
batch = self._balance_batch(batch)

# step 2-3: compute logprob for rollout data
with TimerScope(self.timers, RolloutStages.ROLLOUT_LOGPROB):
with reload_and_offload_scope(self, self.reference_model):
with TimerScope(self.timers, RolloutStages.ROLLOUT_REF_LOGPROB):
batch["ref_log_probs"] = self.reference_trainer.compute_logprob(**batch)

with reload_and_offload_scope(self, self.actor_model):
with TimerScope(self.timers, RolloutStages.ROLLOUT_OLD_LOGPROB):
batch["log_probs"] = self.actor_trainer.compute_logprob(**batch)
with self.autocast_smart_context_manager():

# step 2-3: compute logprob for rollout data
with TimerScope(self.timers, RolloutStages.ROLLOUT_LOGPROB):
with reload_and_offload_scope(self, self.reference_model):
with TimerScope(self.timers, RolloutStages.ROLLOUT_REF_LOGPROB):
batch["ref_log_probs"] = self.reference_trainer.compute_logprob(**batch)


with reload_and_offload_scope(self, self.actor_model):
with TimerScope(self.timers, RolloutStages.ROLLOUT_OLD_LOGPROB):
batch["log_probs"] = self.actor_trainer.compute_logprob(**batch)


# step 2-2: compute reward for rollout data
with TimerScope(
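The ppo_trainer.py change above is purely structural: both the reference and actor log-prob passes are re-indented under `self.autocast_smart_context_manager()`, so the rollout log probabilities are computed under the same mixed-precision policy as the training loss. A self-contained sketch of the idea, using plain `paddle.amp.auto_cast` as a stand-in for the trainer's context manager (an assumption about its role) and callables in place of `compute_logprob`:

```python
import paddle

def compute_rollout_logprobs(reference_logprob_fn, actor_logprob_fn, batch,
                             use_amp=True, amp_dtype="bfloat16"):
    """Run both log-prob passes under one shared autocast context (sketch only)."""
    with paddle.amp.auto_cast(enable=use_amp, dtype=amp_dtype):
        with paddle.no_grad():
            ref_log_probs = reference_logprob_fn(**batch)
            log_probs = actor_logprob_fn(**batch)
    batch["ref_log_probs"] = ref_log_probs
    batch["log_probs"] = log_probs
    return batch
```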