Skip to content

Commit 18eb81d

Browse files
author
Vincent Moens
committed
Update (base update)
[ghstack-poisoned]
1 parent 0bc468c commit 18eb81d

File tree

4 files changed

+73
-68
lines changed

4 files changed

+73
-68
lines changed

.github/unittest/linux_examples/scripts/run_all.sh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,15 +150,15 @@ git submodule sync && git submodule update --init --recursive
150150
printf "Installing PyTorch with %s\n" "${CU_VERSION}"
151151
if [[ "$TORCH_VERSION" == "nightly" ]]; then
152152
if [ "${CU_VERSION:-}" == cpu ] ; then
153-
pip3 install --pre torch torchvision numpy==1.26.4 numpy-base<2.0 --index-url https://download.pytorch.org/whl/nightly/cpu -U
153+
pip3 install --pre torch torchvision numpy==1.26.4 numpy-base==1.26.4 --index-url https://download.pytorch.org/whl/nightly/cpu -U
154154
else
155-
pip3 install --pre torch torchvision numpy==1.26.4 numpy-base<2.0 --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
155+
pip3 install --pre torch torchvision numpy==1.26.4 numpy-base==1.26.4 --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
156156
fi
157157
elif [[ "$TORCH_VERSION" == "stable" ]]; then
158158
if [ "${CU_VERSION:-}" == cpu ] ; then
159-
pip3 install torch torchvision numpy==1.26.4 numpy-base<2.0 --index-url https://download.pytorch.org/whl/cpu
159+
pip3 install torch torchvision numpy==1.26.4 numpy-base==1.26.4 --index-url https://download.pytorch.org/whl/cpu
160160
else
161-
pip3 install torch torchvision numpy==1.26.4 numpy-base<2.0 --index-url https://download.pytorch.org/whl/$CU_VERSION
161+
pip3 install torch torchvision numpy==1.26.4 numpy-base==1.26.4 --index-url https://download.pytorch.org/whl/$CU_VERSION
162162
fi
163163
else
164164
printf "Failed to install pytorch"

.github/unittest/linux_libs/scripts_rlhf/install.sh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,15 @@ git submodule sync && git submodule update --init --recursive
3131
printf "Installing PyTorch with cu121"
3232
if [[ "$TORCH_VERSION" == "nightly" ]]; then
3333
if [ "${CU_VERSION:-}" == cpu ] ; then
34-
pip3 install --pre torch numpy==1.26.4 numpy-base<2.0 --index-url https://download.pytorch.org/whl/nightly/cpu -U
34+
pip3 install --pre torch numpy==1.26.4 numpy-base==1.26.4 --index-url https://download.pytorch.org/whl/nightly/cpu -U
3535
else
36-
pip3 install --pre torch numpy==1.26.4 numpy-base<2.0 --index-url https://download.pytorch.org/whl/nightly/cu121 -U
36+
pip3 install --pre torch numpy==1.26.4 numpy-base==1.26.4 --index-url https://download.pytorch.org/whl/nightly/cu121 -U
3737
fi
3838
elif [[ "$TORCH_VERSION" == "stable" ]]; then
3939
if [ "${CU_VERSION:-}" == cpu ] ; then
40-
pip3 install torch numpy==1.26.4 numpy-base<2.0 --index-url https://download.pytorch.org/whl/cpu
40+
pip3 install torch numpy==1.26.4 numpy-base==1.26.4 --index-url https://download.pytorch.org/whl/cpu
4141
else
42-
pip3 install torch numpy==1.26.4 numpy-base<2.0 --index-url https://download.pytorch.org/whl/cu121
42+
pip3 install torch numpy==1.26.4 numpy-base==1.26.4 --index-url https://download.pytorch.org/whl/cu121
4343
fi
4444
else
4545
printf "Failed to install pytorch"

torchrl/__init__.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,14 @@
5454

5555
_THREAD_POOL_INIT = torch.get_num_threads()
5656

57+
5758
# monkey-patch dist transforms until https://github.com/pytorch/pytorch/pull/135001/ finds a home
5859
@property
59-
def inv(self):
60-
"""
60+
def _inv(self):
61+
"""Patched version of Transform.inv.
62+
6163
Returns the inverse :class:`Transform` of this transform.
64+
6265
This should satisfy ``t.inv.inv is t``.
6366
"""
6467
inv = None
@@ -71,11 +74,11 @@ def inv(self):
7174
return inv
7275

7376

74-
torch.distributions.transforms.Transform.inv = inv
77+
torch.distributions.transforms.Transform.inv = _inv
7578

7679

7780
@property
78-
def inv(self):
81+
def _inv(self):
7982
inv = None
8083
if self._inv is not None:
8184
inv = self._inv()
@@ -91,4 +94,4 @@ def inv(self):
9194
return inv
9295

9396

94-
ComposeTransform.inv = inv
97+
ComposeTransform.inv = _inv

0 commit comments

Comments
 (0)