[RL] logprob compute use the same method #10596

Merged
28 changes: 20 additions & 8 deletions paddlenlp/rl/models/ppo_model_utils.py
@@ -485,16 +485,14 @@
kl_loss_coeff=self.kl_loss_coeff,
loop_chunk_size=1024,
response_start=response_start,
use_actor_fused_loss=self.entropy_coeff <= 0, # currently only support kunbo's fused head loss
use_actor_fused_loss=True, # currently only support kunbo's fused head loss
temperature=self.temperature,
)
with paddle.no_grad():
self.info_buffer["kl_loss"] = (
kl_loss.detach() / self.kl_loss_coeff if self.kl_loss_coeff > 0 else paddle.to_tensor([0.0])
)
self.info_buffer["entropy_loss"] = (
entropy_loss.detach() / self.entropy_coeff if self.entropy_coeff > 0 else paddle.to_tensor([0.0])
)
self.info_buffer["entropy_loss"] = entropy_loss.detach()

self.info_buffer["pure_policy_loss"] = (
pg_loss.detach() / self.pg_loss_coeff if self.pg_loss_coeff > 0 else paddle.to_tensor([0.0])
)
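
Note on the logging change above: the KL and policy terms are still divided by their coefficients before being stored in info_buffer, while the entropy term is now recorded as-is; with use_actor_fused_loss forced on, the fused path appears to compute entropy for monitoring only rather than scaling it into the objective. A minimal sketch of the resulting bookkeeping (illustrative standalone helper, not the PR's exact code):

```python
import paddle

# Minimal sketch (illustrative helper, not the PR's code): recover per-term
# metrics from coefficient-weighted loss components before logging.
def log_loss_terms(info_buffer, pg_loss, kl_loss, entropy_loss,
                   pg_loss_coeff, kl_loss_coeff):
    with paddle.no_grad():
        # Undo the coefficient so the logged value is the raw KL estimate.
        info_buffer["kl_loss"] = (
            kl_loss.detach() / kl_loss_coeff if kl_loss_coeff > 0 else paddle.to_tensor([0.0])
        )
        # The entropy term is logged untouched; it is tracked for monitoring.
        info_buffer["entropy_loss"] = entropy_loss.detach()
        info_buffer["pure_policy_loss"] = (
            pg_loss.detach() / pg_loss_coeff if pg_loss_coeff > 0 else paddle.to_tensor([0.0])
        )
    return info_buffer
```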
@@ -716,6 +714,7 @@
clip_range_score: float,
kl_loss_coeff: float, # KL loss coefficient
temperature: float,
print_entropy_loss: bool = True,
):
"""
forward function of ActorFusedLoss
@@ -813,11 +812,11 @@
token_end_idx = min(i + loop_chunk_size, n_tokens)
hidden_states_chunk = hidden_states[token_start_idx:token_end_idx]
labels_chunk = labels[token_start_idx:token_end_idx]
old_log_probs_chunk = old_log_probs[token_start_idx:token_end_idx]
mask_chunk = loss_mask[token_start_idx:token_end_idx]
old_log_probs_chunk = old_log_probs[token_start_idx:token_end_idx] * mask_chunk

if kl_loss_coeff > 0:
ref_log_chunk = ref_log_probs[token_start_idx:token_end_idx]
ref_log_chunk = ref_log_probs[token_start_idx:token_end_idx] * mask_chunk

advantages_chunk = advantages[token_start_idx:token_end_idx]
mask_chunk = loss_mask[token_start_idx:token_end_idx]

# Calculate the current logits_chunk, not fused linear
logits_chunk_cast = paddle.matmul(hidden_states_chunk, lm_head_weight_cast, transpose_y=transpose_y)
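
The two masked slices above zero out old and reference log-probs at positions excluded by loss_mask before any ratio or KL arithmetic is done. A minimal illustration of the effect (toy tensors, not the PR's code); a masked position then contributes a ratio of exp(0 - 0) = 1 and a KL delta of 0, so it drops out of the loss once the masked advantages are applied:

```python
import paddle

# Toy illustration of masking log-probs before the ratio/KL terms.
old_log_probs = paddle.to_tensor([-1.2, -0.7, -2.3, -0.5])
loss_mask = paddle.to_tensor([1.0, 1.0, 1.0, 0.0])  # last position is padding

old_log_probs_chunk = old_log_probs * loss_mask
print(old_log_probs_chunk.numpy())  # [-1.2 -0.7 -2.3  0. ]
```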
@@ -841,13 +840,14 @@
token_loss_chunk = F.cross_entropy(logits_chunk, labels_chunk, reduction="none")
softmax_output_chunk = F.softmax(logits_chunk, axis=-1)

log_probs_chunk = -token_loss_chunk.squeeze(axis=-1)
log_probs_chunk = -token_loss_chunk.squeeze(axis=-1) * mask_chunk

# calculate gradient, note sign
grad_logits_chunk = labels_one_hot.astype("float32") - softmax_output_chunk
grad_logits_chunk = grad_logits_chunk.astype(dtype)

# ratio
ratio_chunk = paddle.exp(log_probs_chunk - old_log_probs_chunk)

clipped_ratio_chunk = paddle.clip(
ratio_chunk, min=1.0 - clip_range_ratio_low, max=1.0 + clip_range_ratio_high
)
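
For reference, the ratio and clip above form the standard PPO clipped surrogate; clip_range_ratio_low and clip_range_ratio_high are the same arguments that appear in the diff. A minimal per-token sketch (illustrative helper, not the fused implementation, which also folds the gradient into d_loss_d_logits_chunk):

```python
import paddle

# Per-token PPO clipped surrogate (illustrative sketch).
def clipped_surrogate(log_probs, old_log_probs, advantages,
                      clip_range_ratio_low=0.2, clip_range_ratio_high=0.2):
    ratio = paddle.exp(log_probs - old_log_probs)
    clipped_ratio = paddle.clip(
        ratio, min=1.0 - clip_range_ratio_low, max=1.0 + clip_range_ratio_high
    )
    # Take the pessimistic (larger-loss) branch per token.
    return paddle.maximum(-ratio * advantages, -clipped_ratio * advantages)
```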
@@ -892,6 +892,7 @@
if kl_loss_coeff > 0:
# [3] kl loss
delta_chunk = ref_log_chunk - log_probs_chunk

exp_delta_chunk = paddle.exp(delta_chunk)
kl_loss_estimate_chunk = exp_delta_chunk - delta_chunk - 1
kl_loss_clipped_chunk = (
@@ -912,6 +913,17 @@
)
d_loss_d_logits_chunk += d_kl_log_probs_chunk.unsqueeze(-1) * d_log_probs_d_logits_chunk

if print_entropy_loss:

# [2] entropy loss
log_prob_chunk = paddle.log(paddle.clip(softmax_output_chunk, min=1e-12))
entropy_loss_chunk = -(softmax_output_chunk * log_prob_chunk).sum(axis=-1) * mask_chunk

# entropy_loss_chunk shape is [bs, seqlen, vocab_size // tensor_parallel_degree], do all_reduce sum here
if tensor_parallel_degree > 1 and tensor_parallel_output:
paddle.distributed.all_reduce(

entropy_loss_chunk, op=paddle.distributed.ReduceOp.SUM, group=model_parallel_group
)
total_entropy_loss += entropy_loss_chunk.sum() / divisor


# grads
if grad_hidden_states is not None:
grad_hidden_states[token_start_idx:token_end_idx] = paddle.matmul(
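
The new print_entropy_loss branch above computes per-token softmax entropy for logging and, under tensor parallelism, all-reduces the shard-local sums before accumulating total_entropy_loss. A minimal single-device sketch of the entropy term (illustrative shapes; the tensor-parallel all_reduce from the diff is omitted):

```python
import paddle
import paddle.nn.functional as F

# Single-device sketch of the monitored entropy term (illustrative).
def token_entropy(logits, loss_mask):
    probs = F.softmax(logits, axis=-1)
    log_probs = paddle.log(paddle.clip(probs, min=1e-12))  # guard against log(0)
    entropy = -(probs * log_probs).sum(axis=-1)            # per-token entropy
    return entropy * loss_mask                             # drop padded tokens
```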
155 changes: 155 additions & 0 deletions paddlenlp/rl/trainer/actor_trainer.py
@@ -16,6 +16,9 @@

import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.distributed import fleet
from paddle.distributed.fleet.layers.mpu import mp_ops

from paddle.distributed.fleet.meta_parallel import ParallelCrossEntropy

from ..models.ppo_model_utils import (
@@ -57,6 +60,13 @@
Raises:
None.
"""
if self.args.use_fused_head_and_loss_fn:
return self.compute_fused_logprob(

input_ids=input_ids,
position_ids=position_ids,
**kwargs,
)

log_probs_list = []
batch_size, sequence_length = input_ids.shape
per_device_logprob_batch_size = self.args.per_device_logprob_batch_size
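
The new branch above routes the fused-head configuration to compute_fused_logprob (added below); both paths split the batch into micro-batches of per_device_logprob_batch_size, and the batch count is a plain ceiling division, e.g.:

```python
# Ceiling division used for the micro-batch count (worked example).
batch_size, per_device_logprob_batch_size = 10, 4
num_batches = (batch_size + per_device_logprob_batch_size - 1) // per_device_logprob_batch_size
assert num_batches == 3  # micro-batches of 4, 4 and 2 samples
```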
@@ -147,6 +157,151 @@

return paddle.concat(log_probs_list, axis=0)

def compute_fused_logprob(self, input_ids: paddle.Tensor, position_ids: paddle.Tensor = None, **kwargs):
log_probs_list = []
batch_size, sequence_length = input_ids.shape
per_device_logprob_batch_size = self.args.per_device_logprob_batch_size
num_batches = (batch_size + per_device_logprob_batch_size - 1) // per_device_logprob_batch_size


# Pipe model outputs a logits tensor with LMHead, while non-pipe model
# outputs a tuple with logits tensor as the only one element.
startend_row_indices = create_startend_row_indices(input_ids, self.tokenizer.pad_token_id)
response_start = kwargs["prompt"].shape[-1] - 1 if "prompt" in kwargs else 0


for i in range(num_batches):

# Calculate the start and end indices for the current batch
start_index = i * per_device_logprob_batch_size
end_index = min(start_index + per_device_logprob_batch_size, batch_size)


# Extract the current batch
current_input_ids = input_ids[start_index:end_index]
current_startend_row_indices = (

startend_row_indices[start_index:end_index] if startend_row_indices is not None else None
)
current_position_ids = position_ids[start_index:end_index] if position_ids is not None else None
current_labels = current_input_ids[:, response_start + 1 :]


if self.args.use_remove_padding:
from ..utils.bert_padding import prepare_flashmask_inputs


update_inputs = prepare_flashmask_inputs(

current_input_ids,
current_position_ids,
self.tokenizer.pad_token_id,
self.model.config.sequence_parallel,
self.model.config.tensor_parallel_degree,
)
current_input_ids = update_inputs["input_ids"]
current_position_ids = update_inputs["position_ids"]
current_startend_row_indices = update_inputs["attn_mask_startend_row_indices"]
indices = update_inputs["indices"]
raw_input_shape = update_inputs["raw_input_shape"]
pad_size = update_inputs["pad_size"]


# NOTE: for use_fused_head_and_loss_fn
self.model.training = True
hidden_states, lm_head_weight, lm_head_bias, transpose_y = self.model(

current_input_ids,
position_ids=current_position_ids,
attn_mask_startend_row_indices=current_startend_row_indices,
)
self.model.training = False


if self.args.use_remove_padding:
if pad_size > 0:
hidden_states = hidden_states[:, :-pad_size]


from ..utils.bert_padding import pad_input


hidden_states = pad_input(

hidden_states.squeeze(0), indices, batch=raw_input_shape[0], seqlen=raw_input_shape[1]
).contiguous()

if self.args.use_fp32_compute and hidden_states.dtype != paddle.float32:
hidden_states = hidden_states.cast(paddle.float32)
lm_head_weight = lm_head_weight.cast(paddle.float32)
if lm_head_bias is not None:
lm_head_bias = lm_head_bias.cast(paddle.float32)


# Recover
hidden_states = hidden_states[:, response_start:-1, :]
dtype = hidden_states.dtype
original_shape = hidden_states.shape
num_embeddings = self.model.config.vocab_size
loop_chunk_size = 1024
tensor_parallel_degree = self.model.config.tensor_parallel_degree
tensor_parallel_output = self.model.config.tensor_parallel_output
if tensor_parallel_degree > 1:
assert tensor_parallel_output, (

"When tensor_parallel_degree > 1 and use_fused_head_and_loss_fn, "
"tensor_parallel_output needs to be set to True."
)
# Parallel Configuration
if tensor_parallel_degree > 1 and tensor_parallel_output:
hcg = fleet.get_hybrid_communicate_group()
model_parallel_group = hcg.get_model_parallel_group()
tensor_parallel_degree = hcg.get_model_parallel_world_size()


# reshape
hidden_states = hidden_states.reshape([-1, original_shape[-1]])
labels = current_labels.reshape([-1])


n_tokens = hidden_states.shape[0]
n_classes = lm_head_weight.shape[0] if transpose_y else lm_head_weight.shape[1]


# convert dtype of weights and biases of lm_head
lm_head_weight_cast = lm_head_weight.astype(dtype)
if lm_head_bias is not None:
lm_head_bias_cast = lm_head_bias.astype(dtype)


# use indices to distinguish the devices.
if tensor_parallel_degree > 1 and tensor_parallel_output:
rank = hcg.get_model_parallel_rank()
per_part_size = num_embeddings // tensor_parallel_degree
indices = paddle.arange(

rank * per_part_size,
rank * per_part_size + n_classes,
dtype=labels.dtype,
).unsqueeze(0)
else:
indices = paddle.arange(num_embeddings, dtype=labels.dtype).unsqueeze(0)


log_prob_chunks = []
for ci in range(0, n_tokens, loop_chunk_size):
token_start_idx = ci
token_end_idx = min(ci + loop_chunk_size, n_tokens)
hidden_states_chunk = hidden_states[token_start_idx:token_end_idx]
labels_chunk = labels[token_start_idx:token_end_idx]


# Calculate the current logits_chunk, not fused linear
logits_chunk_cast = paddle.matmul(hidden_states_chunk, lm_head_weight_cast, transpose_y=transpose_y)
if lm_head_bias is not None:
logits_chunk_cast += lm_head_bias_cast


logits_chunk = logits_chunk_cast.astype("float32")
logits_chunk = logits_chunk / self.args.temperature


# rewritten as cross entropy
if tensor_parallel_degree > 1 and tensor_parallel_output:
token_loss_chunk = mp_ops._c_softmax_with_cross_entropy(

logits_chunk,
labels_chunk,
group=model_parallel_group,
return_softmax=False,
)
else:
token_loss_chunk = F.cross_entropy(logits_chunk, labels_chunk, reduction="none")
log_prob_chunk = -token_loss_chunk.squeeze(axis=-1)
log_prob_chunks.append(log_prob_chunk)


log_probs = paddle.concat(log_prob_chunks, axis=-1).reshape(original_shape[:-1])
log_probs_list.append(log_probs)


log_prob_chunks = None
paddle.device.cuda.empty_cache()


return paddle.concat(log_probs_list, axis=0)

Check warning on line 303 in paddlenlp/rl/trainer/actor_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/rl/trainer/actor_trainer.py#L303

Added line #L303 was not covered by tests

def update_actor(self, rl_batch: Dict[str, paddle.Tensor]) -> Dict[str, Any]:
# inputs shared by policy and value trainer
input_ids = rl_batch["input_ids"].contiguous() # length: src+tgt
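
Taken together, compute_fused_logprob mirrors the fused-head training path: the model returns hidden states plus the LM-head weight, the response span is projected to the vocabulary chunk by chunk, and each token's log-probability is read off the cross-entropy of its label, so the full [tokens, vocab] logits tensor is never materialized at once. A minimal single-device sketch of that core loop (illustrative shapes and names; the PR additionally handles padding removal, fp32 casts, the temperature from the training args, and the tensor-parallel softmax path):

```python
import paddle
import paddle.nn.functional as F

# Single-device sketch of chunked log-prob extraction (illustrative, not the PR's code).
def chunked_logprob(hidden_states, lm_head_weight, labels,
                    temperature=1.0, loop_chunk_size=1024):
    # hidden_states: [n_tokens, hidden]; lm_head_weight: [hidden, vocab]; labels: [n_tokens] int64
    log_prob_chunks = []
    n_tokens = hidden_states.shape[0]
    for start in range(0, n_tokens, loop_chunk_size):
        end = min(start + loop_chunk_size, n_tokens)
        logits = paddle.matmul(hidden_states[start:end], lm_head_weight)
        logits = logits.astype("float32") / temperature
        # Cross-entropy of the sampled label is exactly -log p(label | context).
        nll = F.cross_entropy(logits, labels[start:end], reduction="none")
        log_prob_chunks.append(-nll.reshape([-1]))
    return paddle.concat(log_prob_chunks, axis=0)
```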
19 changes: 10 additions & 9 deletions paddlenlp/rl/trainer/ppo_trainer.py
@@ -1473,15 +1473,16 @@
if self.args.balance_batch:
batch = self._balance_batch(batch)

# step 2-3: compute logprob for rollout data
with TimerScope(self.timers, RolloutStages.ROLLOUT_LOGPROB):
with reload_and_offload_scope(self, self.reference_model):
with TimerScope(self.timers, RolloutStages.ROLLOUT_REF_LOGPROB):
batch["ref_log_probs"] = self.reference_trainer.compute_logprob(**batch)

with reload_and_offload_scope(self, self.actor_model):
with TimerScope(self.timers, RolloutStages.ROLLOUT_OLD_LOGPROB):
batch["log_probs"] = self.actor_trainer.compute_logprob(**batch)
with self.autocast_smart_context_manager():

# step 2-3: compute logprob for rollout data
with TimerScope(self.timers, RolloutStages.ROLLOUT_LOGPROB):
with reload_and_offload_scope(self, self.reference_model):
with TimerScope(self.timers, RolloutStages.ROLLOUT_REF_LOGPROB):
batch["ref_log_probs"] = self.reference_trainer.compute_logprob(**batch)


with reload_and_offload_scope(self, self.actor_model):
with TimerScope(self.timers, RolloutStages.ROLLOUT_OLD_LOGPROB):
batch["log_probs"] = self.actor_trainer.compute_logprob(**batch)


# step 2-2: compute reward for rollout data
with TimerScope(
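
The ppo_trainer.py change wraps both rollout log-prob passes in autocast_smart_context_manager(), so they run under the same mixed-precision policy as training. A rough sketch of the pattern, with plain paddle.amp.auto_cast standing in for the trainer's own context manager (an assumption; the real manager derives its settings from the training arguments):

```python
import paddle

# Sketch only: paddle.amp.auto_cast stands in for autocast_smart_context_manager.
def compute_logprob_under_amp(trainer, batch):
    with paddle.no_grad():
        with paddle.amp.auto_cast():
            return trainer.compute_logprob(**batch)
```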