Skip to content

Commit 6e4b550

Browse files
author
Vincent Moens
committed
[BugFix] device in args of PPO losses
ghstack-source-id: d5118b5 Pull-Request-resolved: #2969
1 parent dc41223 commit 6e4b550

File tree

2 files changed

+47
-12
lines changed

2 files changed

+47
-12
lines changed

torchrl/objectives/llm/grpo.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ class GRPOLoss(ClipPPOLoss):
7676
estimate was done by the current version of the value estimator. If instead ``True`` is provided, the
7777
``clip_epsilon`` parameter will be used as the clipping threshold. If not provided or ``False``, no
7878
clipping will be performed. Defaults to ``False``.
79+
device (torch.device, optional): device of the buffers. Defaults to ``None``.
80+
81+
.. note:: Parameters and buffers from the policy / critic will not be cast to that device to ensure that
82+
the storages match the ones that are passed to other components, such as data collectors.
7983
"""
8084

8185
actor_network: TensorDictModule
@@ -99,6 +103,7 @@ def __init__(
99103
reduction: str = None,
100104
clip_value: bool | float | None = None,
101105
kl_to_ref_coeff: float | None = None,
106+
device: torch.device = None,
102107
**kwargs,
103108
):
104109
# Define clipping of the value loss
@@ -116,6 +121,7 @@ def __init__(
116121
reduction=reduction,
117122
clip_value=clip_value,
118123
functional=False,
124+
device=device,
119125
**kwargs,
120126
)
121127
# We don't want to use the string action but the tokens

torchrl/objectives/ppo.py

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,10 @@ class PPOLoss(LossModule):
122122
The purpose of clipping is to limit the impact of extreme value predictions, helping stabilize training
123123
and preventing large updates. However, it will have no impact if the value estimate was done by the current
124124
version of the value estimator. Defaults to ``None``.
125+
device (torch.device, optional): device of the buffers. Defaults to ``None``.
126+
127+
.. note:: Parameters and buffers from the policy / critic will not be cast to that device to ensure that
128+
the storages match the ones that are passed to other components, such as data collectors.
125129
126130
.. note::
127131
The advantage (typically GAE) can be computed by the loss function or
@@ -341,6 +345,7 @@ def __init__(
341345
critic: ProbabilisticTensorDictSequential = None,
342346
reduction: str = None,
343347
clip_value: float | None = None,
348+
device: torch.device | None = None,
344349
**kwargs,
345350
):
346351
if actor is not None:
@@ -395,10 +400,13 @@ def __init__(
395400
self.separate_losses = separate_losses
396401
self.reduction = reduction
397402

398-
try:
399-
device = next(self.parameters()).device
400-
except (AttributeError, StopIteration):
401-
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
403+
if device is None:
404+
try:
405+
device = next(self.parameters()).device
406+
except (AttributeError, StopIteration):
407+
device = getattr(
408+
torch, "get_default_device", lambda: torch.device("cpu")
409+
)()
402410

403411
self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device))
404412
if critic_coef is not None:
@@ -422,7 +430,7 @@ def __init__(
422430

423431
if clip_value is not None:
424432
if isinstance(clip_value, float):
425-
clip_value = torch.tensor(clip_value)
433+
clip_value = torch.tensor(clip_value, device=device)
426434
elif isinstance(clip_value, torch.Tensor):
427435
if clip_value.numel() != 1:
428436
raise ValueError(
@@ -866,6 +874,10 @@ class ClipPPOLoss(PPOLoss):
866874
estimate was done by the current version of the value estimator. If instead ``True`` is provided, the
867875
``clip_epsilon`` parameter will be used as the clipping threshold. If not provided or ``False``, no
868876
clipping will be performed. Defaults to ``False``.
877+
device (torch.device, optional): device of the buffers. Defaults to ``None``.
878+
879+
.. note:: Parameters and buffers from the policy / critic will not be cast to that device to ensure that
880+
the storages match the ones that are passed to other components, such as data collectors.
869881
870882
.. note::
871883
The advantage (typically GAE) can be computed by the loss function or
@@ -934,6 +946,7 @@ def __init__(
934946
separate_losses: bool = False,
935947
reduction: str = None,
936948
clip_value: bool | float | None = None,
949+
device: torch.device | None = None,
937950
**kwargs,
938951
):
939952
# Define clipping of the value loss
@@ -954,13 +967,15 @@ def __init__(
954967
separate_losses=separate_losses,
955968
reduction=reduction,
956969
clip_value=clip_value,
957-
**kwargs,
970+
device=device,
**kwargs,
958971
)
959-
for p in self.parameters():
960-
device = p.device
961-
break
962-
else:
963-
device = None
972+
if device is None:
973+
try:
974+
device = next(self.parameters()).device
975+
except (AttributeError, StopIteration):
976+
device = getattr(
977+
torch, "get_default_device", lambda: torch.device("cpu")
978+
)()
964979
self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon, device=device))
965980

966981
@property
@@ -1139,6 +1154,10 @@ class KLPENPPOLoss(PPOLoss):
11391154
The purpose of clipping is to limit the impact of extreme value predictions, helping stabilize training
11401155
and preventing large updates. However, it will have no impact if the value estimate was done by the current
11411156
version of the value estimator. Defaults to ``None``.
1157+
device (torch.device, optional): device of the buffers. Defaults to ``None``.
1158+
1159+
.. note:: Parameters and buffers from the policy / critic will not be cast to that device to ensure that
1160+
the storages match the ones that are passed to other components, such as data collectors.
11421161
11431162
.. note::
11441163
The advantage (typically GAE) can be computed by the loss function or
@@ -1211,6 +1230,7 @@ def __init__(
12111230
separate_losses: bool = False,
12121231
reduction: str = None,
12131232
clip_value: float | None = None,
1233+
device: torch.device | None = None,
12141234
**kwargs,
12151235
):
12161236
super().__init__(
@@ -1227,12 +1247,21 @@ def __init__(
12271247
separate_losses=separate_losses,
12281248
reduction=reduction,
12291249
clip_value=clip_value,
1250+
device=device,
12301251
**kwargs,
12311252
)
12321253

1254+
if device is None:
1255+
try:
1256+
device = next(self.parameters()).device
1257+
except (AttributeError, StopIteration):
1258+
device = getattr(
1259+
torch, "get_default_device", lambda: torch.device("cpu")
1260+
)()
1261+
12331262
self.dtarg = dtarg
12341263
self._beta_init = beta
1235-
self.register_buffer("beta", torch.tensor(beta))
1264+
self.register_buffer("beta", torch.tensor(beta, device=device))
12361265

12371266
if increment < 1.0:
12381267
raise ValueError(

0 commit comments

Comments
 (0)