Skip to content

Commit 2332909

Browse files
author
Vincent Moens
committed
[Feature] Make benchmarked losses compatible with torch.compile
ghstack-source-id: 825ded5 Pull Request resolved: #2405
1 parent 605b4aa commit 2332909

File tree

26 files changed

+940
-277
lines changed

26 files changed

+940
-277
lines changed

.github/unittest/linux/scripts/run_all.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,13 @@ if [[ "$TORCH_VERSION" == "nightly" ]]; then
127127
if [ "${CU_VERSION:-}" == cpu ] ; then
128128
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U
129129
else
130-
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
130+
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION -U
131131
fi
132132
elif [[ "$TORCH_VERSION" == "stable" ]]; then
133133
if [ "${CU_VERSION:-}" == cpu ] ; then
134-
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
134+
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu -U
135135
else
136-
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/$CU_VERSION
136+
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/$CU_VERSION -U
137137
fi
138138
else
139139
printf "Failed to install pytorch"

.github/unittest/linux_examples/scripts/run_all.sh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,15 +150,15 @@ git submodule sync && git submodule update --init --recursive
150150
printf "Installing PyTorch with %s\n" "${CU_VERSION}"
151151
if [[ "$TORCH_VERSION" == "nightly" ]]; then
152152
if [ "${CU_VERSION:-}" == cpu ] ; then
153-
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U
153+
pip3 install --pre torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/nightly/cpu -U
154154
else
155-
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
155+
pip3 install --pre torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
156156
fi
157157
elif [[ "$TORCH_VERSION" == "stable" ]]; then
158158
if [ "${CU_VERSION:-}" == cpu ] ; then
159-
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
159+
pip3 install torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/cpu
160160
else
161-
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/$CU_VERSION
161+
pip3 install torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/$CU_VERSION
162162
fi
163163
else
164164
printf "Failed to install pytorch"

.github/unittest/linux_libs/scripts_brax/install.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ if [[ "$TORCH_VERSION" == "nightly" ]]; then
3434
fi
3535
elif [[ "$TORCH_VERSION" == "stable" ]]; then
3636
if [ "${CU_VERSION:-}" == cpu ] ; then
37-
pip3 install torch --index-url https://download.pytorch.org/whl/cpu
37+
pip3 install torch --index-url https://download.pytorch.org/whl/cpu -U
3838
else
3939
pip3 install torch --index-url https://download.pytorch.org/whl/cu121
4040
fi

.github/unittest/linux_libs/scripts_openx/install.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ if [[ "$TORCH_VERSION" == "nightly" ]]; then
3737
fi
3838
elif [[ "$TORCH_VERSION" == "stable" ]]; then
3939
if [ "${CU_VERSION:-}" == cpu ] ; then
40-
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
40+
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu -U
4141
else
42-
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121
42+
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121 -U
4343
fi
4444
else
4545
printf "Failed to install pytorch"

.github/unittest/linux_libs/scripts_rlhf/install.sh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,15 @@ git submodule sync && git submodule update --init --recursive
3131
printf "Installing PyTorch with cu121"
3232
if [[ "$TORCH_VERSION" == "nightly" ]]; then
3333
if [ "${CU_VERSION:-}" == cpu ] ; then
34-
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
34+
pip3 install --pre torch numpy==1.26.4 --index-url https://download.pytorch.org/whl/nightly/cpu -U
3535
else
36-
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U
36+
pip3 install --pre torch numpy==1.26.4 --index-url https://download.pytorch.org/whl/nightly/cu121 -U
3737
fi
3838
elif [[ "$TORCH_VERSION" == "stable" ]]; then
3939
if [ "${CU_VERSION:-}" == cpu ] ; then
40-
pip3 install torch --index-url https://download.pytorch.org/whl/cpu
40+
pip3 install torch numpy==1.26.4 --index-url https://download.pytorch.org/whl/cpu
4141
else
42-
pip3 install torch --index-url https://download.pytorch.org/whl/cu121
42+
pip3 install torch numpy==1.26.4 --index-url https://download.pytorch.org/whl/cu121
4343
fi
4444
else
4545
printf "Failed to install pytorch"

.github/unittest/linux_libs/scripts_vd4rl/install.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ if [[ "$TORCH_VERSION" == "nightly" ]]; then
3737
fi
3838
elif [[ "$TORCH_VERSION" == "stable" ]]; then
3939
if [ "${CU_VERSION:-}" == cpu ] ; then
40-
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
40+
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu -U
4141
else
4242
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121
4343
fi

.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ printf "Installing PyTorch with %s\n" "${CU_VERSION}"
3939
if [ "${CU_VERSION:-}" == cpu ] ; then
4040
conda install pytorch==2.0 torchvision==0.15 cpuonly -c pytorch -y
4141
else
42-
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 numpy==1.26 numpy-base==1.26 -c pytorch -c nvidia -y
42+
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 numpy==1.26 -c pytorch -c nvidia -y
4343
fi
4444

4545
# Solving circular import: https://stackoverflow.com/questions/75501048/how-to-fix-attributeerror-partially-initialized-module-charset-normalizer-has

.github/unittest/linux_optdeps/scripts/install.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")"
2020
git submodule sync && git submodule update --init --recursive
2121

2222
printf "Installing PyTorch with %s\n" "${CU_VERSION}"
23-
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
23+
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION -U
2424

2525
# install tensordict
2626
if [[ "$RELEASE" == 0 ]]; then

.github/workflows/benchmarks.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ jobs:
4040
- name: Run benchmarks
4141
run: |
4242
cd benchmarks/
43+
export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1
4344
python -m pytest --benchmark-json output.json
4445
- name: Store benchmark results
4546
if: ${{ github.ref == 'refs/heads/main' || github.event_name == 'workflow_dispatch' }}
@@ -107,6 +108,7 @@ jobs:
107108
- name: Run benchmarks
108109
run: |
109110
cd benchmarks/
111+
export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1
110112
python3 -m pytest --benchmark-json output.json
111113
- name: Store benchmark results
112114
uses: benchmark-action/github-action-benchmark@v1

.github/workflows/benchmarks_pr.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ jobs:
4646
- name: Run benchmarks
4747
run: |
4848
cd benchmarks/
49+
export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1
4950
RUN_BENCHMARK="pytest --rank 0 --benchmark-json "
5051
git checkout ${{ github.event.pull_request.base.sha }}
5152
$RUN_BENCHMARK ${{ env.BASELINE_JSON }}
@@ -125,6 +126,7 @@ jobs:
125126
- name: Run benchmarks
126127
run: |
127128
cd benchmarks/
129+
export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1
128130
RUN_BENCHMARK="pytest --rank 0 --benchmark-json "
129131
git checkout ${{ github.event.pull_request.base.sha }}
130132
$RUN_BENCHMARK ${{ env.BASELINE_JSON }}

0 commit comments

Comments
 (0)