
Update documentation on pytorch multi gpu setup #2687

Open · wants to merge 5 commits into ``main``
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
@@ -36,6 +36,7 @@ Guidelines for modifications:
## Contributors

* Alessandro Assirelli
* Alex Omar
* Alice Zhou
* Amr Mousa
* Andrej Orsula
64 changes: 51 additions & 13 deletions docs/source/features/multi_gpu.rst
@@ -16,19 +16,54 @@ other workflows.
Multi-GPU Training
------------------

For complex reinforcement learning environments, it may be desirable to scale up training across multiple GPUs.
This is possible in Isaac Lab through the use of the
`PyTorch distributed <https://pytorch.org/docs/stable/distributed.html>`_ framework or the
`JAX distributed <https://jax.readthedocs.io/en/latest/jax.distributed.html>`_ module respectively.

In PyTorch, the :mod:`torch.distributed` API is used to launch multiple processes of training, where the number of
processes must be equal to or less than the number of GPUs available. Each process runs on
a dedicated GPU and launches its own instance of Isaac Sim and the Isaac Lab environment.
Each process collects its own rollouts during training and keeps its own copy of the policy
network. During training, gradients are aggregated across the processes and broadcast back to each process
at the end of the epoch.

Isaac Lab supports the following multi-GPU training frameworks:

* `Torchrun <https://docs.pytorch.org/docs/stable/elastic/run.html>`_ through `PyTorch distributed <https://pytorch.org/docs/stable/distributed.html>`_
* `JAX distributed <https://jax.readthedocs.io/en/latest/jax.distributed.html>`_

PyTorch Torchrun Implementation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We use `PyTorch Torchrun <https://docs.pytorch.org/docs/stable/elastic/run.html>`_ to manage multi-GPU
training. Torchrun manages the distributed training by:

* **Process Management**: Launching one process per GPU, where each process is assigned to a specific GPU.
* **Script Execution**: Running the same training script (e.g., RL Games trainer) on each process.
* **Environment Instances**: Each process creates its own instance of the Isaac Lab environment.
* **Gradient Synchronization**: Aggregating gradients across all processes and broadcasting the synchronized
gradients back to each process after each training step.
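
As a concrete illustration of the process-management step: Torchrun sets environment variables such as ``RANK``, ``LOCAL_RANK``, and ``WORLD_SIZE`` for every process it spawns, and the training script can read them to bind to its dedicated GPU. The sketch below is illustrative only, not Isaac Lab's actual startup code:

```python
import os

def select_device() -> str:
    """Pick the CUDA device for this process from torchrun's env vars.

    torchrun sets RANK (global rank), LOCAL_RANK (rank on this node), and
    WORLD_SIZE (total number of processes) before the script starts.
    """
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    assert local_rank < world_size
    # Each process binds to exactly one dedicated GPU.
    return f"cuda:{local_rank}"

# Example: values torchrun would set for the second process on a 2-GPU node.
os.environ.update({"RANK": "1", "LOCAL_RANK": "1", "WORLD_SIZE": "2"})
print(select_device())  # cuda:1
```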

.. tip::
   Check out this `3-minute YouTube video from PyTorch <https://www.youtube.com/watch?v=Cvdhwx-OBBo&list=PL_lsbAsL_o2CSuhUhJIiW0IkdT5C2wGWj&index=2>`_
   to understand how Torchrun works.

The key components in this setup are:

* **Torchrun**: Handles process spawning, communication, and gradient synchronization.
* **RL Library**: The reinforcement learning library that runs the actual training algorithm.
* **Isaac Lab**: Provides the simulation environment that each process instantiates independently.

Under the hood, gradient synchronization follows the `DistributedDataParallel <https://docs.pytorch.org/docs/2.7/notes/ddp.html#internal-design>`_
design. When training with multiple GPUs using Torchrun, the following happens:

* Each GPU runs an independent process
* Each process executes the full training script
* Each process maintains its own:

  * Isaac Lab environment instance (with *n* parallel environments)
  * Policy network copy
  * Experience buffer for rollout collection

* All processes synchronize only for gradient updates
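
The synchronization step can be illustrated without any GPUs: an all-reduce averages every worker's local gradients and returns the same result to each worker, so the policy copies stay identical after every optimizer step. Below is a pure-Python sketch of that averaging (an illustration of the concept, not PyTorch's actual NCCL-based implementation):

```python
def all_reduce_mean(worker_grads):
    """Average per-parameter gradients across workers (the averaging that
    an all-reduce performs during the backward pass) and hand every
    worker the same result."""
    world_size = len(worker_grads)
    n_params = len(worker_grads[0])
    averaged = [
        sum(grads[i] for grads in worker_grads) / world_size
        for i in range(n_params)
    ]
    # Every worker receives an identical copy, so their policy
    # networks remain in sync after the optimizer step.
    return [list(averaged) for _ in range(world_size)]

# Two workers, each holding local gradients for two parameters.
local = [[1.0, 3.0], [3.0, 5.0]]
print(all_reduce_mean(local))  # [[2.0, 4.0], [2.0, 4.0]]
```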

For a deeper dive into how Torchrun works, check out
`PyTorch Docs: DistributedDataParallel - Internal Design <https://pytorch.org/docs/stable/notes/ddp.html#internal-design>`_.

JAX Implementation
^^^^^^^^^^^^^^^^^^

.. tip::
   JAX is only supported with the skrl library.

With JAX, we use the `skrl.utils.distributed.jax <https://skrl.readthedocs.io/en/latest/api/utils/distributed.html>`_ module.
Since JAX does not automatically start multiple processes from a single program invocation,
the skrl library provides a module to start them.
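
What such a launcher module does can be sketched with the Python standard library: start one process per device, hand each process its rank, and wait for all of them to finish. This is an illustrative stand-in for the concept, not skrl's actual launcher:

```python
import multiprocessing as mp

def worker(rank, world_size, results):
    # A real launcher would set rank/world-size environment variables and
    # exec the training script; here each process just reports its identity.
    results.put((rank, f"process {rank} of {world_size}"))

if __name__ == "__main__":
    world_size = 2  # one process per GPU
    results = mp.Queue()
    procs = [mp.Process(target=worker, args=(r, world_size, results))
             for r in range(world_size)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    print(sorted(results.get() for _ in range(world_size)))
```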

.. image:: ../_static/multi-gpu-rl/a3c-light.svg
@@ -45,6 +80,9 @@ the skrl library provides a module to start them.

|

Running Multi-GPU Training
^^^^^^^^^^^^^^^^^^^^^^^^^^

To train with multiple GPUs, use the following command, where ``--nproc_per_node`` represents the number of available GPUs:

.. tab-set::