
Commit dc41223

Author: Vincent Moens (committed)
[Feature] Losses (GRPO)
ghstack-source-id: e053882
Pull-Request-resolved: #2968
1 parent 1ba8c84 commit dc41223

File tree: 4 files changed (+507, -0 lines changed)


docs/source/reference/llms.rst

Lines changed: 18 additions & 0 deletions
@@ -129,3 +129,21 @@ Utils
    make_vllm_worker
    stateless_init_process_group
    vLLMWorker

Objectives
----------

LLM post-training requires appropriate versions of the losses implemented in TorchRL.

GRPO
~~~~

.. currentmodule:: torchrl.objectives.llm

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    GRPOLoss
    GRPOLossOutput
    MCAdvantage
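A minimal usage sketch of the new objective, assuming the API exercised by the test added below; `policy_train` and `batch_size` are placeholders rather than names introduced by this diff, and the per-prompt grouping is inferred from the `grpo_size` argument:

    from torchrl.data import LazyStackStorage, ReplayBuffer
    from torchrl.objectives.llm import GRPOLoss, MCAdvantage

    # MCAdvantage is used as a replay-buffer transform: it accumulates completions,
    # presumably groups them per prompt (grpo_size per group), and writes an
    # "advantage" entry into the sampled tensordicts.
    rb = ReplayBuffer(storage=LazyStackStorage(100))
    rb.append_transform(MCAdvantage(grpo_size=4))

    # ... extend `rb` with rollouts from an inference policy, then:
    loss = GRPOLoss(actor_network=policy_train)  # policy_train: TransformersWrapper(..., generate=False)
    loss_vals = loss(rb.sample(batch_size))      # returns a GRPOLossOutput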

test/llm/test_objectives.py

Lines changed: 158 additions & 0 deletions
@@ -0,0 +1,158 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import argparse
import importlib.util

import numpy as np
import pytest
import torch
from mocking_classes import DummyStrDataLoader

from tensordict import lazy_stack, set_capture_non_tensor_stack, TensorDict
from torchrl.data import LazyStackStorage, ReplayBuffer, Unbounded
from torchrl.envs import Transform
from torchrl.envs.llm import LLMEnv
from torchrl.modules.llm import TransformersWrapper
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.llm.grpo import GRPOLoss, GRPOLossOutput, MCAdvantage

_has_transformers = importlib.util.find_spec("transformers") is not None
prompts = [
    "Lorem ipsum dolor sit amet,",
    "consectetur adipiscing elit,",
    "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.",
    "Ut enim ad minim veniam, quis nostrud exercitation",
    "ullamco laboris nisi ut aliquip ex ea commodo consequat.",
    "Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore",
    "eu fugiat nulla pariatur.",
]


@pytest.mark.parametrize("ndim", [1, 2])
def test_mc_advantage(ndim):
    # make trajectories
    def make_silly_trajectory(n_steps=None):
        while True:
            if n_steps is None:
                n_steps = torch.randint(low=2, high=100, size=(1,)).item()
            tds = []
            for _ in range(n_steps):
                n_tokens = torch.randint(low=1, high=100, size=(1,)).item()
                rewards = [torch.randn(n_tokens, 1)]
                prompt = np.random.choice(prompts)
                td = TensorDict(
                    text=prompt,
                    next=TensorDict(
                        reward=rewards, done=torch.zeros(1, dtype=torch.bool)
                    ),
                )
                tds.append(td)
            tds[-1]["next", "done"] = torch.ones(1, dtype=torch.bool)
            yield lazy_stack(tds)

    rb = ReplayBuffer(storage=LazyStackStorage(100))
    rb.append_transform(MCAdvantage(grpo_size=4))
    if ndim == 1:
        gen = make_silly_trajectory()
        for _ in range(100):
            trajectory = next(gen)
            rb.extend(trajectory)
        assert len(rb)
        s = rb.sample(1)
        assert "advantage" in s.keys()
    else:
        gen = make_silly_trajectory(n_steps=5)
        for _ in range(100):
            trajectory = lazy_stack([next(gen) for _ in range(3)])
            trajectory = trajectory.view(-1)
            rb.extend(trajectory)
        assert len(rb)
        s = rb.sample(1)
        assert "advantage" in s.keys()


def test_grpo():
    ...


class TestPPO4LLMs:
    @pytest.mark.skipif(
        not _has_transformers, reason="transformers lib required to test PPO with LLMs"
    )
    @set_capture_non_tensor_stack(False)
    @pytest.mark.parametrize("from_text", [True, False])
    @pytest.mark.parametrize("cls", [ClipPPOLoss, GRPOLoss])
    def test_hf(self, from_text, cls):
        from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM

        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
        tokenizer.pad_token = tokenizer.eos_token

        model = OPTForCausalLM(OPTConfig()).eval()
        policy_inference = TransformersWrapper(
            model,
            tokenizer=tokenizer,
            generate=True,
            from_text=from_text,
            return_log_probs=True,
        )
        policy_train = TransformersWrapper(
            model, tokenizer=tokenizer, generate=False, from_text=False
        )
        for p in policy_train.parameters():
            assert p.requires_grad
        # Create some fake data
        dl = DummyStrDataLoader(batch_size=32)
        llm_env = LLMEnv.from_dataloader(
            dl,
            tokenizer=tokenizer if not from_text else None,
            batch_size=(32,),
            from_text=True,
            eos_token_id=tokenizer.eos_token_id,
        )

        class RewardTransform(Transform):
            def _step(self, td, next_td):
                next_td["reward"] = torch.randn_like(
                    td["tokens_response"], dtype=torch.float
                ).unsqueeze(-1)
                return next_td

            def transform_reward_spec(self, reward_spec):
                return reward_spec.set(
                    "reward", Unbounded((*reward_spec.shape, -1, 1), dtype=torch.float)
                )

        llm_env = llm_env.append_transform(RewardTransform())
        with torch.no_grad():
            data = llm_env.rollout(3, policy_inference)
            data = data.view(-1)
            assert data["tokens_response"].shape[-1] == 20
        # Make some fake advantages:
        data["advantage"] = torch.randn_like(data["next", "reward"])

        loss = cls(
            actor_network=policy_train,
        )
        loss_vals = loss(data)
        if cls is ClipPPOLoss:
            assert "loss_objective" in loss_vals
            assert "loss_entropy" in loss_vals
            assert loss_vals["loss_objective"].requires_grad
            assert loss_vals["loss_entropy"].requires_grad
            assert "clip_fraction" in loss_vals
            assert "kl_approx" in loss_vals
            assert "entropy" in loss_vals
            assert "ESS" in loss_vals
            assert "loss_critic" not in loss_vals
        else:
            assert isinstance(loss_vals, GRPOLossOutput)


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
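As a side note on what the MCAdvantage test is exercising: the standard GRPO estimator computes group-relative advantages as the reward minus the mean reward of the completions sharing a prompt, usually normalized by the group standard deviation. Whether MCAdvantage applies exactly this normalization is not shown in this diff, so the following is only an illustrative sketch of that textbook formula:

    import torch

    # Rewards for grpo_size completions of the same prompt (made-up numbers).
    rewards = torch.tensor([1.0, 0.0, 0.5, 0.5])
    advantage = (rewards - rewards.mean()) / rewards.std().clamp_min(1e-6)
    # -> tensor([ 1.2247, -1.2247,  0.0000,  0.0000]), up to rounding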

torchrl/objectives/llm/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from .grpo import GRPOLoss, GRPOLossOutput, MCAdvantage

__all__ = ["GRPOLoss", "GRPOLossOutput", "MCAdvantage"]

0 commit comments
