
Commit 003924e

Vincent Moens (vmoens) authored and committed
[Algorithm] GRPO scripts
ghstack-source-id: d8ddac5
Pull-Request-resolved: #2970
1 parent 023c965 commit 003924e

File tree

25 files changed, +1234 -185 lines

.gitignore

Lines changed: 1 addition & 1 deletion
@@ -180,4 +180,4 @@ log
 .DS_Store
 Roms

-scratch/*.py
+scratch/*
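The pattern change above widens the ignore rule from Python files only to everything under `scratch/`. A small stdlib sketch of the difference (gitignore matching has more rules than `fnmatch`; the globs and file names here are only illustrative):

```python
from fnmatch import fnmatch

# Hypothetical scratch files, used only to contrast the two patterns.
paths = ["scratch/test.py", "scratch/notes.md", "scratch/data.bin"]

# Old rule: only Python files under scratch/ are ignored.
old = [p for p in paths if fnmatch(p, "scratch/*.py")]
# New rule: everything under scratch/ is ignored.
new = [p for p in paths if fnmatch(p, "scratch/*")]

print(old)
print(new)
```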

README.md

Lines changed: 41 additions & 111 deletions
@@ -826,85 +826,52 @@ If you're using TorchRL, please refer to this BibTeX entry to cite this work:
 
 ## Installation
 
-Create a conda environment where the packages will be installed.
+### Create a new virtual environment:
+```bash
+python -m venv torchrl
+source torchrl/bin/activate  # On Windows use: venv\Scripts\activate
+```
+
+Or create a conda environment where the packages will be installed.
 
 ```
-conda create --name torch_rl python=3.9
-conda activate torch_rl
+conda create --name torchrl python=3.9
+conda activate torchrl
 ```
 
-**PyTorch**
+### Install dependencies:
 
-Depending on the use of functorch that you want to make, you may want to
+#### PyTorch
+
+Depending on the use of torchrl that you want to make, you may want to
 install the latest (nightly) PyTorch release or the latest stable version of PyTorch.
 See [here](https://pytorch.org/get-started/locally/) for a detailed list of commands,
 including `pip3` or other special installation instructions.
 
-**Torchrl**
+TorchRL offers a few pre-defined dependencies such as `"torchrl[tests]"`, `"torchrl[atari]"` etc.
+
+#### Torchrl
 
 You can install the **latest stable release** by using
 ```bash
 pip3 install torchrl
 ```
-This should work on linux, Windows 10 and OsX (Intel or Silicon chips).
-On certain Windows machines (Windows 11), one should install the library locally (see below).
-
-For AArch64 machines, the binaries are not yet stored on PyPI so you will need to download them directly from
-the [release page](https://github.com/pytorch/rl/releases/) or install the library via
-```
-pip3 install git+https://github.com/pytorch/rl@v0.8.0
-```
+This should work on linux (including AArch64 machines), Windows 10 and OsX (Metal chips only).
+On certain Windows machines (Windows 11), one should build the library locally.
+This can be done in two ways:
 
-The **nightly build** can be installed via
-```bash
-pip3 install tensordict-nightly torchrl-nightly
-```
-which we currently only ship for Linux machines.
-Importantly, the nightly builds require the nightly builds of PyTorch too.
-
-To install extra dependencies, call
-```bash
-pip3 install "torchrl[atari,dm_control,gym_continuous,rendering,tests,utils,marl,open_spiel,checkpointing]"
-```
-or a subset of these.
-
-To install torchrl with the latest pytorch, use
-```bash
-pip3 install "torchrl[replay_buffer]"
-```
-since some features in the replay buffer require PyTorch 2.7.0 or above.
-
-One may also desire to install the library locally. Three main reasons can motivate this:
-- the nightly/stable release isn't available for one's platform (eg, Windows 11, nightlies for Apple Silicon etc.);
-- contributing to the code;
-- install torchrl with a previous version of PyTorch (any version >= 2.1) (note that this should also be doable via a regular install followed
-  by a downgrade to a previous pytorch version -- but the C++ binaries will not be available so some feature will not work,
-  such as prioritized replay buffers and the like.)
-
-**Disclaimer**: As of today, TorchRL is roughly compatible with any pytorch version >= 2.1 and installing it will not
-directly require a newer version of pytorch to be installed. Indirectly though, tensordict still requires the latest
-PyTorch to be installed and we are working hard to loosen that requirement.
-The C++ binaries of TorchRL (mainly for prioritized replay buffers) will only work with PyTorch 2.7.0 and above.
-Some features (e.g., working with nested jagged tensors) may also
-be limited with older versions of pytorch. It is recommended to use the latest TorchRL with the latest PyTorch version
-unless there is a strong reason not to do so.
-
-To install the library locally, start by cloning the repo:
 ```bash
+# Install and build locally v0.8.1 of the library without cloning
+pip3 install git+https://github.com/pytorch/rl@v0.8.1
+# Clone the library and build it locally
+git clone https://github.com/pytorch/tensordict
 git clone https://github.com/pytorch/rl
-```
-and don't forget to check out the branch or tag you want to use for the build:
-```bash
-git checkout v0.8.0
+pip install -e tensordict
+pip install -e rl
 ```
 
-Go to the directory where you have cloned the torchrl repo and install it (after
-installing `ninja`)
-```bash
-cd /path/to/torchrl/
-pip3 install ninja -U
-python setup.py develop
-```
+Note that tensordict local build requires `cmake` to be installed via [homebrew](https://brew.sh/) (MacOS) or another package manager
+such as `apt`, `apt-get`, `conda` or `yum` but NOT `pip`, as well as `pip install "pybind11[global]"`.
 
 One can also build the wheels to distribute to co-workers using
 ```bash
@@ -915,22 +882,22 @@ Your wheels will be stored there `./dist/torchrl<name>.whl` and installable via
 pip install torchrl<name>.whl
 ```
 
-**Warning**: Unfortunately, `pip3 install -e .` does not currently work. Contributions to help fix this are welcome!
-
-On M1 machines, this should work out-of-the-box with the nightly build of PyTorch.
-If the generation of this artifact in MacOs M1 doesn't work correctly or in the execution the message
-`(mach-o file, but is an incompatible architecture (have 'x86_64', need 'arm64e'))` appears, then try
-
-```
-ARCHFLAGS="-arch arm64" python setup.py develop
+The **nightly build** can be installed via
+```bash
+pip3 install tensordict-nightly torchrl-nightly
 ```
+which we currently only ship for Linux machines.
+Importantly, the nightly builds require the nightly builds of PyTorch too.
+Also, a local build of torchrl with the nightly build of tensordict may fail - install both nightlies or both local builds but do not mix them.
 
-To run a quick sanity check, leave that directory (e.g. by executing `cd ~/`)
-and try to import the library.
-```
-python -c "import torchrl"
-```
-This should not return any warning or error.
+
+**Disclaimer**: As of today, TorchRL is roughly compatible with any pytorch version >= 2.1 and installing it will not
+directly require a newer version of pytorch to be installed. Indirectly though, tensordict still requires the latest
+PyTorch to be installed and we are working hard to loosen that requirement.
+The C++ binaries of TorchRL (mainly for prioritized replay buffers) will only work with PyTorch 2.7.0 and above.
+Some features (e.g., working with nested jagged tensors) may also
+be limited with older versions of pytorch. It is recommended to use the latest TorchRL with the latest PyTorch version
+unless there is a strong reason not to do so.
 
 **Optional dependencies**
 
@@ -959,43 +926,6 @@ pip3 install tensorboard
 pip3 install wandb
 ```
 
-**Troubleshooting**
-
-If a `ModuleNotFoundError: No module named ‘torchrl._torchrl` errors occurs (or
-a warning indicating that the C++ binaries could not be loaded),
-it means that the C++ extensions were not installed or not found.
-
-- One common reason might be that you are trying to import torchrl from within the
-  git repo location. The following code snippet should return an error if
-  torchrl has not been installed in `develop` mode:
-  ```
-  cd ~/path/to/rl/repo
-  python -c 'from torchrl.envs.libs.gym import GymEnv'
-  ```
-  If this is the case, consider executing torchrl from another location.
-- If you're not importing torchrl from within its repo location, it could be
-  caused by a problem during the local installation. Check the log after the
-  `python setup.py develop`. One common cause is a g++/C++ version discrepancy
-  and/or a problem with the `ninja` library.
-- If the problem persists, feel free to open an issue on the topic in the repo,
-  we'll make our best to help!
-- On **MacOs**, we recommend installing XCode first.
-  With Apple Silicon M1 chips, make sure you are using the arm64-built python
-  (e.g. [here](https://betterprogramming.pub/how-to-install-pytorch-on-apple-m1-series-512b3ad9bc6)).
-  Running the following lines of code
-  ```
-  wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
-  python collect_env.py
-  ```
-  should display
-  ```
-  OS: macOS *** (arm64)
-  ```
-  and not
-  ```
-  OS: macOS **** (x86_64)
-  ```
-
 Versioning issues can cause error message of the type ```undefined symbol```
 and such. For these, refer to the [versioning issues document](https://github.com/pytorch/rl/blob/main/knowledge_base/VERSIONING_ISSUES.md)
 for a complete explanation and proposed workarounds.

setup.py

Lines changed: 1 addition & 7 deletions
@@ -195,13 +195,7 @@ def _main(argv):
     sys.argv = [sys.argv[0]] + unknown
 
     extra_requires = {
-        "atari": [
-            "gym",
-            "atari-py",
-            "ale-py",
-            "gym[accept-rom-license]",
-            "pygame",
-        ],
+        "atari": ["gymnasium[atari]"],
         "dm_control": ["dm_control"],
         "replay_buffer": ["torch>=2.7.0"],
         "gym_continuous": ["gymnasium<1.0", "mujoco"],
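The diff above collapses the `atari` extra to a single `gymnasium[atari]` requirement. A minimal sketch of how such an `extra_requires` mapping resolves extras to dependency lists (the dict mirrors the post-diff setup.py; `resolve_extras` is an illustrative helper, not pip's resolver):

```python
# Extras mapping mirroring the post-diff setup.py; the resolution
# helper below is illustrative only, not how pip actually resolves.
extra_requires = {
    "atari": ["gymnasium[atari]"],
    "dm_control": ["dm_control"],
    "replay_buffer": ["torch>=2.7.0"],
    "gym_continuous": ["gymnasium<1.0", "mujoco"],
}

def resolve_extras(extras):
    """Flatten the dependency lists for the requested extras."""
    deps = []
    for name in extras:
        deps.extend(extra_requires[name])
    return deps

print(resolve_extras(["atari", "replay_buffer"]))
```

On the user side, this mapping is what a command such as `pip install "torchrl[atari]"` consults.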

sota-implementations/grpo/README.md

Lines changed: 136 additions & 0 deletions
# GRPO: Generalized Reward-Conditioned Policy Optimization

This is an implementation of GRPO for language models, built on top of TorchRL.

## Overview

GRPO is a method for training language models using reinforcement learning, with the following key features:
- Multi-GPU support with efficient device management
- Mixed precision training
- Gradient accumulation
- Automatic checkpointing
- Comprehensive logging with Weights & Biases
- Hydra configuration system
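Among the features above, gradient accumulation is easy to sketch framework-agnostically: gradients from several micro-batches are averaged before a single parameter update (a toy scalar illustration, not the script's actual training loop):

```python
# Toy sketch of gradient accumulation on a scalar parameter: sum the
# per-micro-batch gradients, average them, then take one update step.
def accumulated_step(param, micro_batch_grads, lr=0.1):
    accum = 0.0
    for grad in micro_batch_grads:  # one "backward pass" per micro-batch
        accum += grad               # gradients accumulate; no update yet
    accum /= len(micro_batch_grads)
    return param - lr * accum       # single optimizer step

updated = accumulated_step(1.0, [0.2, 0.4])
```

This lets an effective batch size larger than what fits in GPU memory be simulated by several smaller forward/backward passes.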

## Installation

1. Install dependencies:
```bash
# GSM8K deps
pip install -r sota-implementations/grpo/requirements_gsm8k.txt
# IFEval deps
pip install -r sota-implementations/grpo/requirements_ifeval.txt
```

2. Set required environment variables:
```bash
export VLLM_USE_V1=0  # Required for vLLM compatibility
```

## Hardware Requirements

- At least 3 CUDA-capable GPUs:
  - Training device(s)
  - vLLM inference device
  - Reference model device

Devices can be controlled via the `training_model.devices`, `inference_model.devices` and `ref_model.devices` arguments.

## Configuration

The training configuration is managed through Hydra. There are two main configuration files:
- `config/grpo_gsm8k.yaml`: Default configuration for GSM8K tasks (default)
- `config/grpo_ifeval.yaml`: Configuration optimized for IFEval tasks

## Usage

### Basic Training

```bash
python grpo.py
```

### Run with IFEval Config

```bash
python grpo.py --config-name grpo_ifeval
```

### Override Config Values

```bash
# Change dataset
python grpo.py env.dataset=ifeval

# Modify training parameters
python grpo.py train.epochs=2 train.optimizer.lr=2e-5

# Change model
python grpo.py model.name=meta-llama/Llama-2-7b-hf
```

### Hyperparameter Sweeps

```bash
# Learning rate sweep
python grpo.py --multirun train.optimizer.lr=1e-4,1e-5,1e-6

# Multiple parameters
python grpo.py --multirun \
  train.optimizer.lr=1e-4,1e-5 \
  policy.kl_coef=0.01,0.1
```

## Monitoring

Training progress is logged to Weights & Biases with the following metrics:
- Reward
- Advantage
- KL penalty
- Sequence length
- ESS (Effective Sample Size)
- Loss metrics (objective, clip fraction, etc.)
- Gradient norm

## Checkpointing

Checkpoints are saved every `logging.checkpoint_frequency` batches and contain:
- Model state
- Optimizer state
- Gradient scaler state (for mixed precision)
- Full configuration
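The checkpoint fields listed above can be sketched as a single serialized dict. This is a hypothetical layout for illustration only; the actual script presumably serializes with `torch.save`, while `pickle` keeps the sketch dependency-free:

```python
import pickle

def save_checkpoint(path, model_state, optim_state, scaler_state, config):
    # Hypothetical checkpoint layout mirroring the fields listed above.
    ckpt = {
        "model": model_state,      # model state
        "optimizer": optim_state,  # optimizer state
        "scaler": scaler_state,    # gradient scaler state (mixed precision)
        "config": config,          # full configuration
    }
    with open(path, "wb") as f:
        pickle.dump(ckpt, f)

def load_checkpoint(path):
    with open(path, "rb") as f:
        return pickle.load(f)
```

Bundling the optimizer and scaler state alongside the model is what makes resuming a mixed-precision run bitwise-consistent with an uninterrupted one.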

## Debugging Out-of-memory issues

- vLLM: Reduce `inference_model.gpu_memory_utilization=FRACTION` or number of environments run
  in parallel (`env.num_envs=N`).
- KL scoring: If the KL scoring is achieved on the batch of data,
  reduce the number of environments (`env.num_envs=N`) run in parallel.
- Training: Reduce batch size (`train.optim_batch_size`)

## Directory Structure

```
sota-implementations/grpo/
├── config/
│   └── grpo_gsm8k.yaml   # Main configuration file
│   └── grpo_ifeval.yaml  # config file for IFEval task
├── grpo.py               # Training script
├── grpo_utils.py         # Utility functions
└── README.md             # This file
```

## Output Structure

Each run creates a timestamped directory under `outputs/`:
```
outputs/
└── YYYY-MM-DD/
    └── HH-MM-SS/
        ├── checkpoints/
        │   └── checkpoint_*.pt
        └── .hydra/
            └── config.yaml
```

For hyperparameter sweeps, outputs are stored under `multirun/`.

0 commit comments
