@@ -122,6 +122,10 @@ class PPOLoss(LossModule):
122
122
The purpose of clipping is to limit the impact of extreme value predictions, helping stabilize training
123
123
and preventing large updates. However, it will have no impact if the value estimate was done by the current
124
124
version of the value estimator. Defaults to ``None``.
125
+ device (torch.device, optional): device of the buffers. Defaults to ``None``.
126
+
127
+ .. note:: Parameters and buffers from the policy / critic will not be cast to that device to ensure that
128
+ the storages match the ones that are passed to other components, such as data collectors.
125
129
126
130
.. note::
127
131
The advantage (typically GAE) can be computed by the loss function or
@@ -341,6 +345,7 @@ def __init__(
341
345
critic : ProbabilisticTensorDictSequential = None ,
342
346
reduction : str = None ,
343
347
clip_value : float | None = None ,
348
+ device : torch .device | None = None ,
344
349
** kwargs ,
345
350
):
346
351
if actor is not None :
@@ -395,10 +400,13 @@ def __init__(
395
400
self .separate_losses = separate_losses
396
401
self .reduction = reduction
397
402
398
- try :
399
- device = next (self .parameters ()).device
400
- except (AttributeError , StopIteration ):
401
- device = getattr (torch , "get_default_device" , lambda : torch .device ("cpu" ))()
403
+ if device is None :
404
+ try :
405
+ device = next (self .parameters ()).device
406
+ except (AttributeError , StopIteration ):
407
+ device = getattr (
408
+ torch , "get_default_device" , lambda : torch .device ("cpu" )
409
+ )()
402
410
403
411
self .register_buffer ("entropy_coef" , torch .tensor (entropy_coef , device = device ))
404
412
if critic_coef is not None :
@@ -422,7 +430,7 @@ def __init__(
422
430
423
431
if clip_value is not None :
424
432
if isinstance (clip_value , float ):
425
- clip_value = torch .tensor (clip_value )
433
+ clip_value = torch .tensor (clip_value , device = device )
426
434
elif isinstance (clip_value , torch .Tensor ):
427
435
if clip_value .numel () != 1 :
428
436
raise ValueError (
@@ -866,6 +874,10 @@ class ClipPPOLoss(PPOLoss):
866
874
estimate was done by the current version of the value estimator. If instead ``True`` is provided, the
867
875
``clip_epsilon`` parameter will be used as the clipping threshold. If not provided or ``False``, no
868
876
clipping will be performed. Defaults to ``False``.
877
+ device (torch.device, optional): device of the buffers. Defaults to ``None``.
878
+
879
+ .. note:: Parameters and buffers from the policy / critic will not be cast to that device to ensure that
880
+ the storages match the ones that are passed to other components, such as data collectors.
869
881
870
882
.. note::
871
883
The advantage (typically GAE) can be computed by the loss function or
@@ -934,6 +946,7 @@ def __init__(
934
946
separate_losses : bool = False ,
935
947
reduction : str = None ,
936
948
clip_value : bool | float | None = None ,
949
+ device : torch .device | None = None ,
937
950
** kwargs ,
938
951
):
939
952
# Define clipping of the value loss
@@ -954,13 +967,15 @@ def __init__(
954
967
separate_losses = separate_losses ,
955
968
reduction = reduction ,
956
969
clip_value = clip_value ,
957
- ** kwargs ,
970
+ device = device ,
+ ** kwargs ,
958
971
)
959
- for p in self .parameters ():
960
- device = p .device
961
- break
962
- else :
963
- device = None
972
+ if device is None :
973
+ try :
974
+ device = next (self .parameters ()).device
975
+ except (AttributeError , StopIteration ):
976
+ device = getattr (
977
+ torch , "get_default_device" , lambda : torch .device ("cpu" )
978
+ )()
964
979
self .register_buffer ("clip_epsilon" , torch .tensor (clip_epsilon , device = device ))
965
980
966
981
@property
@@ -1139,6 +1154,10 @@ class KLPENPPOLoss(PPOLoss):
1139
1154
The purpose of clipping is to limit the impact of extreme value predictions, helping stabilize training
1140
1155
and preventing large updates. However, it will have no impact if the value estimate was done by the current
1141
1156
version of the value estimator. Defaults to ``None``.
1157
+ device (torch.device, optional): device of the buffers. Defaults to ``None``.
1158
+
1159
+ .. note:: Parameters and buffers from the policy / critic will not be cast to that device to ensure that
1160
+ the storages match the ones that are passed to other components, such as data collectors.
1142
1161
1143
1162
.. note::
1144
1163
The advantage (typically GAE) can be computed by the loss function or
@@ -1211,6 +1230,7 @@ def __init__(
1211
1230
separate_losses : bool = False ,
1212
1231
reduction : str = None ,
1213
1232
clip_value : float | None = None ,
1233
+ device : torch .device | None = None ,
1214
1234
** kwargs ,
1215
1235
):
1216
1236
super ().__init__ (
@@ -1227,12 +1247,21 @@ def __init__(
1227
1247
separate_losses = separate_losses ,
1228
1248
reduction = reduction ,
1229
1249
clip_value = clip_value ,
1250
+ device = device ,
1230
1251
** kwargs ,
1231
1252
)
1232
1253
1254
+ if device is None :
1255
+ try :
1256
+ device = next (self .parameters ()).device
1257
+ except (AttributeError , StopIteration ):
1258
+ device = getattr (
1259
+ torch , "get_default_device" , lambda : torch .device ("cpu" )
1260
+ )()
1261
+
1233
1262
self .dtarg = dtarg
1234
1263
self ._beta_init = beta
1235
- self .register_buffer ("beta" , torch .tensor (beta ))
1264
+ self .register_buffer ("beta" , torch .tensor (beta , device = device ))
1236
1265
1237
1266
if increment < 1.0 :
1238
1267
raise ValueError (
0 commit comments