Skip to content

Commit dd52226

Browse files
author
Vincent Moens
committed
Update (base update)
[ghstack-poisoned]
1 parent cb0b85e commit dd52226

File tree

4 files changed

+21
-12
lines changed

4 files changed

+21
-12
lines changed

.github/unittest/linux_examples/scripts/run_all.sh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,15 +150,15 @@ git submodule sync && git submodule update --init --recursive
150150
printf "Installing PyTorch with %s\n" "${CU_VERSION}"
151151
if [[ "$TORCH_VERSION" == "nightly" ]]; then
152152
if [ "${CU_VERSION:-}" == cpu ] ; then
153-
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U
153+
pip3 install --pre torch torchvision numpy==1.26.4 numpy-base<2.0 --index-url https://download.pytorch.org/whl/nightly/cpu -U
154154
else
155-
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
155+
pip3 install --pre torch torchvision numpy==1.26.4 numpy-base<2.0 --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
156156
fi
157157
elif [[ "$TORCH_VERSION" == "stable" ]]; then
158158
if [ "${CU_VERSION:-}" == cpu ] ; then
159-
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
159+
pip3 install torch torchvision numpy==1.26.4 numpy-base<2.0 --index-url https://download.pytorch.org/whl/cpu
160160
else
161-
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/$CU_VERSION
161+
pip3 install torch torchvision numpy==1.26.4 numpy-base<2.0 --index-url https://download.pytorch.org/whl/$CU_VERSION
162162
fi
163163
else
164164
printf "Failed to install pytorch"

.github/unittest/linux_libs/scripts_rlhf/install.sh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,15 @@ git submodule sync && git submodule update --init --recursive
3131
printf "Installing PyTorch with cu121"
3232
if [[ "$TORCH_VERSION" == "nightly" ]]; then
3333
if [ "${CU_VERSION:-}" == cpu ] ; then
34-
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
34+
pip3 install --pre torch numpy==1.26.4 numpy-base<2.0 --index-url https://download.pytorch.org/whl/nightly/cpu -U
3535
else
36-
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U
36+
pip3 install --pre torch numpy==1.26.4 numpy-base<2.0 --index-url https://download.pytorch.org/whl/nightly/cu121 -U
3737
fi
3838
elif [[ "$TORCH_VERSION" == "stable" ]]; then
3939
if [ "${CU_VERSION:-}" == cpu ] ; then
40-
pip3 install torch --index-url https://download.pytorch.org/whl/cpu
40+
pip3 install torch numpy==1.26.4 numpy-base<2.0 --index-url https://download.pytorch.org/whl/cpu
4141
else
42-
pip3 install torch --index-url https://download.pytorch.org/whl/cu121
42+
pip3 install torch numpy==1.26.4 numpy-base<2.0 --index-url https://download.pytorch.org/whl/cu121
4343
fi
4444
else
4545
printf "Failed to install pytorch"

torchrl/objectives/common.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from dataclasses import dataclass
1313
from typing import Iterator, List, Optional, Tuple
1414

15-
import torch.compiler
1615
from tensordict import is_tensor_collection, TensorDict, TensorDictBase
1716

1817
from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams
@@ -25,12 +24,17 @@
2524
from torchrl.objectives.utils import RANDOM_MODULE_LIST, ValueEstimators
2625
from torchrl.objectives.value import ValueEstimatorBase
2726

27+
try:
28+
from torch.compiler import is_dynamo_compiling
29+
except ImportError:
30+
from torch._dynamo import is_compiling as is_dynamo_compiling
31+
2832

2933
def _updater_check_forward_prehook(module, *args, **kwargs):
3034
if (
3135
not all(module._has_update_associated.values())
3236
and RL_WARNINGS
33-
and not torch.compiler.is_dynamo_compiling()
37+
and not is_dynamo_compiling()
3438
):
3539
warnings.warn(
3640
module.TARGET_NET_WARNING,
@@ -425,7 +429,7 @@ def __getattr__(self, item):
425429
elif (
426430
not self._has_update_associated[item[7:-7]]
427431
and RL_WARNINGS
428-
and not torch.compiler.is_dynamo_compiling()
432+
and not is_dynamo_compiling()
429433
):
430434
# no updater associated
431435
warnings.warn(

torchrl/objectives/ppo.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -804,7 +804,12 @@ def __init__(
804804
clip_value=clip_value,
805805
**kwargs,
806806
)
807-
self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon))
807+
for p in self.parameters():
808+
device = p.device
809+
break
810+
else:
811+
device = None
812+
self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon, device=device))
808813

809814
@property
810815
def _clip_bounds(self):

0 commit comments

Comments
 (0)