File tree Expand file tree Collapse file tree 4 files changed +21
-12
lines changed Expand file tree Collapse file tree 4 files changed +21
-12
lines changed Original file line number Diff line number Diff line change @@ -150,15 +150,15 @@ git submodule sync && git submodule update --init --recursive
150
150
printf " Installing PyTorch with %s\n" " ${CU_VERSION} "
151
151
if [[ " $TORCH_VERSION " == " nightly" ]]; then
152
152
if [ " ${CU_VERSION:- } " == cpu ] ; then
153
- pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U
153
+ pip3 install --pre torch torchvision numpy==1.26.4 numpy-base < 2.0 --index-url https://download.pytorch.org/whl/nightly/cpu -U
154
154
else
155
- pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
155
+ pip3 install --pre torch torchvision numpy==1.26.4 numpy-base < 2.0 --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
156
156
fi
157
157
elif [[ " $TORCH_VERSION " == " stable" ]]; then
158
158
if [ " ${CU_VERSION:- } " == cpu ] ; then
159
- pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
159
+ pip3 install torch torchvision numpy==1.26.4 numpy-base < 2.0 --index-url https://download.pytorch.org/whl/cpu
160
160
else
161
- pip3 install torch torchvision --index-url https://download.pytorch.org/whl/$CU_VERSION
161
+ pip3 install torch torchvision numpy==1.26.4 numpy-base < 2.0 --index-url https://download.pytorch.org/whl/$CU_VERSION
162
162
fi
163
163
else
164
164
printf " Failed to install pytorch"
Original file line number Diff line number Diff line change @@ -31,15 +31,15 @@ git submodule sync && git submodule update --init --recursive
31
31
printf " Installing PyTorch with cu121"
32
32
if [[ " $TORCH_VERSION " == " nightly" ]]; then
33
33
if [ " ${CU_VERSION:- } " == cpu ] ; then
34
- pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
34
+ pip3 install --pre torch numpy==1.26.4 numpy-base < 2.0 --index-url https://download.pytorch.org/whl/nightly/cpu -U
35
35
else
36
- pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U
36
+ pip3 install --pre torch numpy==1.26.4 numpy-base < 2.0 --index-url https://download.pytorch.org/whl/nightly/cu121 -U
37
37
fi
38
38
elif [[ " $TORCH_VERSION " == " stable" ]]; then
39
39
if [ " ${CU_VERSION:- } " == cpu ] ; then
40
- pip3 install torch --index-url https://download.pytorch.org/whl/cpu
40
+ pip3 install torch numpy==1.26.4 numpy-base < 2.0 --index-url https://download.pytorch.org/whl/cpu
41
41
else
42
- pip3 install torch --index-url https://download.pytorch.org/whl/cu121
42
+ pip3 install torch numpy==1.26.4 numpy-base < 2.0 --index-url https://download.pytorch.org/whl/cu121
43
43
fi
44
44
else
45
45
printf " Failed to install pytorch"
Original file line number Diff line number Diff line change 12
12
from dataclasses import dataclass
13
13
from typing import Iterator , List , Optional , Tuple
14
14
15
- import torch .compiler
16
15
from tensordict import is_tensor_collection , TensorDict , TensorDictBase
17
16
18
17
from tensordict .nn import TensorDictModule , TensorDictModuleBase , TensorDictParams
25
24
from torchrl .objectives .utils import RANDOM_MODULE_LIST , ValueEstimators
26
25
from torchrl .objectives .value import ValueEstimatorBase
27
26
27
+ try :
28
+ from torch .compiler import is_dynamo_compiling
29
+ except ImportError :
30
+ from torch ._dynamo import is_compiling as is_dynamo_compiling
31
+
28
32
29
33
def _updater_check_forward_prehook (module , * args , ** kwargs ):
30
34
if (
31
35
not all (module ._has_update_associated .values ())
32
36
and RL_WARNINGS
33
- and not torch . compiler . is_dynamo_compiling ()
37
+ and not is_dynamo_compiling ()
34
38
):
35
39
warnings .warn (
36
40
module .TARGET_NET_WARNING ,
@@ -425,7 +429,7 @@ def __getattr__(self, item):
425
429
elif (
426
430
not self ._has_update_associated [item [7 :- 7 ]]
427
431
and RL_WARNINGS
428
- and not torch . compiler . is_dynamo_compiling ()
432
+ and not is_dynamo_compiling ()
429
433
):
430
434
# no updater associated
431
435
warnings .warn (
Original file line number Diff line number Diff line change @@ -804,7 +804,12 @@ def __init__(
804
804
clip_value = clip_value ,
805
805
** kwargs ,
806
806
)
807
- self .register_buffer ("clip_epsilon" , torch .tensor (clip_epsilon ))
807
+ for p in self .parameters ():
808
+ device = p .device
809
+ break
810
+ else :
811
+ device = None
812
+ self .register_buffer ("clip_epsilon" , torch .tensor (clip_epsilon , device = device ))
808
813
809
814
@property
810
815
def _clip_bounds (self ):
You can’t perform that action at this time.
0 commit comments