Formatting and Tests for Fine Grained / Process Reward Models #5

Draft · wants to merge 7 commits into main
8 changes: 5 additions & 3 deletions compose_rl/data/preference_data.py
@@ -300,13 +300,15 @@ def __getitem__(self, idx: int) -> dict[str, Any]:
idx (int): the index where we fetch the data in the StreamingDataset.
"""
sample = super().__getitem__(idx)
- text = self._read_binary_tokenized_sample(sample, 'text')
- label = self._read_binary_tokenized_sample(sample, 'label')
+ text = self._read_binary_tokenized_sample(sample, 'input')
+ label = torch.from_numpy(np.frombuffer(sample['label'], dtype=np.uint8))
+ # This needs to be a float tensor for BCE
+ label = label.to(torch.float32)

text_len = len(text)

return {
'text': text,
- 'label': label,
+ 'labels': label,
'text_len': torch.Tensor([text_len]).to(torch.int64),
}
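
As a side note on the label change above, here is a minimal, self-contained sketch (not part of the diff; the raw bytes and values are illustrative) of why the uint8 label buffer is decoded and then cast to float32: F.binary_cross_entropy_with_logits expects floating-point targets.

import numpy as np
import torch

raw = np.array([1], dtype=np.uint8).tobytes()  # stand-in for sample['label']
label = torch.from_numpy(np.frombuffer(raw, dtype=np.uint8).copy())  # .copy() avoids a read-only buffer warning
label = label.to(torch.float32)  # BCE-with-logits needs float targets
print(label)  # tensor([1.])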
69 changes: 69 additions & 0 deletions compose_rl/metrics/reward_model_metrics.py
@@ -40,3 +40,72 @@ def compute(self):
assert isinstance(self.correct, Tensor)
assert isinstance(self.total, Tensor)
return self.correct / self.total


class ClassificationAccuracy(Metric):
"""Classification accuracy metric.

Computes the accuracy of a classifier by comparing predictions from logits
against ground truth labels. Handles both binary and multi-class
classification.
"""

# Make torchmetrics call update only once
full_state_update = False

def __init__(
self,
binary: bool = True,
threshold: float = 0.5,
dist_sync_on_step: bool = False,
**kwargs: Any,
):
"""Initialize the metric.

Args:
binary: If True, treats as binary classification with sigmoid.
If False, treats as multi-class with softmax.
threshold: Decision threshold for binary classification.
dist_sync_on_step: Synchronize metric state across processes.
"""
super().__init__(dist_sync_on_step=dist_sync_on_step)
self.binary = binary
self.threshold = threshold

self.add_state(
'correct',
default=torch.tensor(0.),
dist_reduce_fx='sum',
)
self.add_state('total', default=torch.tensor(0.), dist_reduce_fx='sum')

def update(self, batch: dict, output_logits: torch.Tensor):
"""Update state with predictions and targets.

Args:
batch: Dictionary containing 'output_scores' and 'labels'.
output_logits: Unused; present for interface compatibility (pass `None`).
"""
del output_logits
logits = batch['output_scores']
# TODO: this might break something, need to double check
targets = batch['labels'].squeeze(-1)
assert logits.shape[0] == targets.shape[0], 'Batch sizes must match'

if self.binary:
# Binary classification
probs = torch.sigmoid(logits.squeeze())
predictions = (probs > self.threshold).long()
else:
# Multi-class classification
probs = torch.softmax(logits, dim=1)
predictions = torch.argmax(probs, dim=1)

self.correct += (predictions == targets).sum().detach().cpu()
self.total += targets.shape[0]

def compute(self):
"""Compute the accuracy."""
assert isinstance(self.correct, Tensor)
assert isinstance(self.total, Tensor)
return self.correct / self.total
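
A hedged usage sketch of the new metric (not part of the diff; logits and labels are made up, and the import assumes the module path mirrors the file path above): update it with a batch dict shaped like the one classifier_forward returns.

import torch
from compose_rl.metrics.reward_model_metrics import ClassificationAccuracy

metric = ClassificationAccuracy(binary=True, threshold=0.5)
batch = {
    'output_scores': torch.tensor([[2.0], [-1.0], [0.3]]),  # raw logits, shape (batch, 1)
    'labels': torch.tensor([[1.0], [0.0], [0.0]]),  # float targets, shape (batch, 1)
}
metric.update(batch, output_logits=None)  # output_logits is unused by this metric
print(metric.compute())  # tensor(0.6667): 2 of the 3 toy examples are classified correctly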
2 changes: 2 additions & 0 deletions compose_rl/reward_learning/__init__.py
@@ -18,6 +18,7 @@
)
from compose_rl.reward_learning.inference_model import InferenceRewardModel
from compose_rl.reward_learning.model import (
ComposerHFClassifierRewardModel,
ComposerHFPairwiseRewardModel,
ComposerMPTPairwiseRewardModel,
)
@@ -32,6 +33,7 @@
'RewardModel',
'ComposerMPTPairwiseRewardModel',
'ComposerHFPairwiseRewardModel',
'ComposerHFClassifierRewardModel',
'InferenceRewardModel',
'BadGenerationEndReward',
'IncreasingNumbersReward',
85 changes: 66 additions & 19 deletions compose_rl/reward_learning/model.py
@@ -12,7 +12,10 @@
from compose_rl.reward_learning.base_reward import RewardModel, Tokenizer
from compose_rl.reward_learning.hf_utils import SequenceClassifierOutput
from compose_rl.reward_learning.model_methods import (
ClassifierRewardEnum,
PairwiseRewardEnum,
classifier_forward,
classifier_loss,
pairwise_forward,
pairwise_loss,
)
@@ -47,12 +50,6 @@ def __init__(
'return_logits': return_lm_logits,
}

- if 'config_overrides' in kwargs:
- config_overrides.update(kwargs.pop('config_overrides'))
-
- self.min_threshold = kwargs.pop('min_threshold', None)
- self.max_threshold = kwargs.pop('max_threshold', None)

super().__init__(
tokenizer=tokenizer,
use_train_metrics=use_train_metrics,
@@ -62,24 +59,14 @@
**kwargs,
)

- def forward(
- self,
- batch: MutableMapping,
- ) -> Union[dict[str, torch.Tensor], torch.Tensor]:
+ def forward(self, batch: MutableMapping) -> dict[str, torch.Tensor]:
is_inference = batch.get('is_inference', False)
if is_inference:
- scores = self.model(
+ return self.model(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
return_lm_logits=self.return_lm_logits,
).scores
- if self.min_threshold is not None and self.max_threshold is not None:
- scores: torch.Tensor = torch.clamp(
- scores,
- min=self.min_threshold,
- max=self.max_threshold,
- )
- return scores
else:
return pairwise_forward(
model=self.model,
@@ -93,7 +80,7 @@ def eval_forward(
self,
batch: MutableMapping,
outputs: Optional[SequenceClassifierOutput] = None,
- ) -> Union[dict[str, torch.Tensor], torch.Tensor]:
+ ) -> dict[str, torch.Tensor]:
return outputs if outputs is not None else self.forward(batch)

def loss(self, outputs: SequenceClassifierOutput,
@@ -105,6 +92,66 @@ def loss(self, outputs: SequenceClassifierOutput,
)


class ComposerHFClassifierRewardModel(
ComposerHFSequenceClassification,
RewardModel,
):

def __init__(
self,
tokenizer: Tokenizer,
use_train_metrics: bool = True,
additional_train_metrics: Optional[list] = None,
additional_eval_metrics: Optional[list] = None,
loss_type: str = 'bce',
return_lm_logits: bool = False,
return_last: bool = True,
**kwargs: Any,
):
self.loss_type = ClassifierRewardEnum(loss_type)
self.return_lm_logits = return_lm_logits
self.return_last = return_last

config_overrides = {
'return_logits': return_lm_logits,
}

super().__init__(
tokenizer=tokenizer,
use_train_metrics=use_train_metrics,
additional_train_metrics=additional_train_metrics,
additional_eval_metrics=additional_eval_metrics,
config_overrides=config_overrides,
**kwargs,
)

def forward(self, batch: MutableMapping) -> dict[str, torch.Tensor]:
ret_val = classifier_forward(
model=self.model,
tokenizer=self.tokenizer,
batch=batch,
return_last=self.return_last,
return_lm_logits=self.return_lm_logits,
)

return ret_val

def eval_forward(
self,
batch: MutableMapping,
outputs: Optional[SequenceClassifierOutput] = None,
) -> dict[str, torch.Tensor]:
return outputs if outputs is not None else self.forward(batch)

def loss(self, outputs: SequenceClassifierOutput,
batch: Mapping) -> dict[str, torch.Tensor]:
return classifier_loss(
outputs,
batch,
self.loss_type,
)


class ComposerMPTPairwiseRewardModel(ComposerMPTCausalLM, RewardModel):
"""MPT model wrapper for Pairwise/BT reward model."""

69 changes: 69 additions & 0 deletions compose_rl/reward_learning/model_methods.py
@@ -34,6 +34,10 @@ class PairwiseRewardEnum(Enum):
BELLMAN_EURUS = 'bellman_eurus'


class ClassifierRewardEnum(Enum):
BCE = 'bce'


def pairwise_forward(
model: nn.Module,
tokenizer: Tokenizer,
@@ -162,6 +166,40 @@ def pairwise_forward(
return outputs


def classifier_forward(
model: nn.Module,
tokenizer: Tokenizer,
batch: MutableMapping,
policy_model_config: Optional[PretrainedConfig] = None,
use_attention_sequence_id: bool = False,
return_last: bool = True,
return_lm_logits: bool = False,
) -> dict[str, torch.Tensor]:

model_output = model(
batch['text'],
attention_mask=batch['text_attention_mask'],
return_lm_logits=return_lm_logits,
)

output_scores = model_output.scores
if return_last:
# Expected Shape: (Batch Size, 1)
output_scores = torch.gather(
output_scores,
dim=1,
index=batch['text_len'].view(-1, 1) - 1,
)

# We need to add the labels here to compute metrics
outputs: dict[str, torch.Tensor] = {
'output_scores': output_scores,
'labels': batch['labels'],
}

return outputs
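
A toy illustration (not from the PR; the scores and lengths are made up) of the return_last gather above: given per-token scores of shape (batch, seq_len), it selects each sequence's score at index text_len - 1.

import torch

output_scores = torch.tensor([[0.1, 0.7, -0.2, 0.0],
                              [0.5, -0.3, 0.0, 0.0]])  # (batch, seq_len) per-token scores
text_len = torch.tensor([3, 2])  # unpadded sequence lengths
last = torch.gather(output_scores, dim=1, index=text_len.view(-1, 1) - 1)
print(last)  # tensor([[-0.2000], [-0.3000]]) -- the score at each sequence's final real token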


def pairwise_loss(
outputs: SequenceClassifierOutput,
batch: Mapping,
@@ -219,3 +257,34 @@ def pairwise_loss(
loss_dict['total'] = losses

return loss_dict


def classifier_loss(
outputs: SequenceClassifierOutput,
batch: Mapping,
loss_type: ClassifierRewardEnum,
) -> dict[str, torch.Tensor]:
"""Computes Classifier loss.

Given precomputed values this will compute the specified classifier loss.

Args:
outputs (SequenceClassifierOutput): Outputs from forwarding the model over the batch.
batch (Mapping): Input batch of data.
loss_type (ClassifierRewardEnum): The loss type to compute (e.g. bce).
"""
output_scores = outputs['output_scores']

if loss_type == ClassifierRewardEnum.BCE:
loss = F.binary_cross_entropy_with_logits(
output_scores,
batch['labels'],
)
else:
raise NotImplementedError(f'Loss type: {loss_type} is not supported.')

loss_dict = {
'total': loss,
}

return loss_dict
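
And a minimal, self-contained check (illustrative numbers only) of what the BCE branch expects: logits of shape (batch, 1) paired with float targets of the same shape.

import torch
import torch.nn.functional as F

output_scores = torch.tensor([[1.5], [-0.5]])  # gathered last-token scores
labels = torch.tensor([[1.0], [0.0]])  # float targets, matching shape
loss = F.binary_cross_entropy_with_logits(output_scores, labels)
print(loss)  # scalar tensor: the mean BCE over the batch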