Skip to content

Commit 0bc468c

Browse files
author
Vincent Moens
committed
Update (base update)
[ghstack-poisoned]
1 parent 38776a1 commit 0bc468c

File tree

8 files changed

+17
-10
lines changed

8 files changed

+17
-10
lines changed

.github/unittest/linux/scripts/run_all.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,13 @@ if [[ "$TORCH_VERSION" == "nightly" ]]; then
127127
if [ "${CU_VERSION:-}" == cpu ] ; then
128128
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U
129129
else
130-
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
130+
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION -U
131131
fi
132132
elif [[ "$TORCH_VERSION" == "stable" ]]; then
133133
if [ "${CU_VERSION:-}" == cpu ] ; then
134-
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
134+
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu -U
135135
else
136-
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/$CU_VERSION
136+
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/$CU_VERSION -U
137137
fi
138138
else
139139
printf "Failed to install pytorch"

.github/unittest/linux_libs/scripts_brax/install.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ if [[ "$TORCH_VERSION" == "nightly" ]]; then
3434
fi
3535
elif [[ "$TORCH_VERSION" == "stable" ]]; then
3636
if [ "${CU_VERSION:-}" == cpu ] ; then
37-
pip3 install torch --index-url https://download.pytorch.org/whl/cpu
37+
pip3 install torch --index-url https://download.pytorch.org/whl/cpu -U
3838
else
3939
pip3 install torch --index-url https://download.pytorch.org/whl/cu121
4040
fi

.github/unittest/linux_libs/scripts_openx/install.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ if [[ "$TORCH_VERSION" == "nightly" ]]; then
3737
fi
3838
elif [[ "$TORCH_VERSION" == "stable" ]]; then
3939
if [ "${CU_VERSION:-}" == cpu ] ; then
40-
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
40+
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu -U
4141
else
42-
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121
42+
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121 -U
4343
fi
4444
else
4545
printf "Failed to install pytorch"

.github/unittest/linux_libs/scripts_vd4rl/install.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ if [[ "$TORCH_VERSION" == "nightly" ]]; then
3737
fi
3838
elif [[ "$TORCH_VERSION" == "stable" ]]; then
3939
if [ "${CU_VERSION:-}" == cpu ] ; then
40-
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
40+
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu -U
4141
else
4242
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121
4343
fi

.github/unittest/linux_optdeps/scripts/install.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")"
2020
git submodule sync && git submodule update --init --recursive
2121

2222
printf "Installing PyTorch with %s\n" "${CU_VERSION}"
23-
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
23+
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION -U
2424

2525
# install tensordict
2626
if [[ "$RELEASE" == 0 ]]; then

test/test_cost.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@
146146
_split_and_pad_sequence,
147147
)
148148

149+
TORCH_VERSION = torch.__version__
149150

150151
# Capture all warnings
151152
pytestmark = [
@@ -15644,6 +15645,7 @@ def __init__(self):
1564415645
assert p.device == dest
1564515646

1564615647

15648+
@pytest.mark.skipif(TORCH_VERSION < "2.5", reason="requires torch>=2.5")
1564715649
def test_exploration_compile():
1564815650
m = ProbabilisticTensorDictModule(
1564915651
in_keys=["loc", "scale"],

torchrl/modules/distributions/continuous.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@
1212
import numpy as np
1313
import torch
1414
from torch import distributions as D, nn
15-
from torch.compiler import assume_constant_result
15+
16+
try:
17+
from torch.compiler import assume_constant_result
18+
except ImportError:
19+
from torch._dynamo import assume_constant_result
20+
1621
from torch.distributions import constraints
1722
from torch.distributions.transforms import _InverseTransform
1823

torchrl/objectives/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
try:
2828
from torch.compiler import is_dynamo_compiling
29-
except ImportError:
29+
except ModuleNotFoundError:
3030
from torch._dynamo import is_compiling as is_dynamo_compiling
3131

3232

0 commit comments

Comments
 (0)