diff --git a/docs/source-fabric/advanced/model_init.rst b/docs/source-fabric/advanced/model_init.rst
index f5f76e8aa087b..f7e11f2dc4210 100644
--- a/docs/source-fabric/advanced/model_init.rst
+++ b/docs/source-fabric/advanced/model_init.rst
@@ -69,7 +69,7 @@ When training distributed models with :doc:`FSDP/TP ` or D

 .. code-block:: python

-    # Recommended for FSDP, TP and DeepSpeed
+    # Recommended for FSDP and TP
     with fabric.init_module(empty_init=True):
         model = GPT3()  # parameters are placed on the meta-device

@@ -79,6 +79,17 @@
     optimizer = torch.optim.Adam(model.parameters())
     optimizer = fabric.setup_optimizers(optimizer)

+With DeepSpeed Stage 3, the :meth:`~lightning.fabric.fabric.Fabric.init_module` context manager is required so that the model gets sharded correctly instead of being loaded onto the GPU in its entirety. DeepSpeed also requires the model and optimizer to be set up jointly.
+
+.. code-block:: python
+
+    # Required with DeepSpeed Stage 3
+    with fabric.init_module(empty_init=True):
+        model = GPT3()
+
+    optimizer = torch.optim.Adam(model.parameters())
+    model, optimizer = fabric.setup(model, optimizer)
+
 .. note::
     Empty-init is experimental and the behavior may change in the future.
     For distributed models, it is required that all user-defined modules that manage parameters implement a ``reset_parameters()`` method (all PyTorch built-in modules have this too).
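For readers who want to try the added snippet outside the docs, the sketch below shows a minimal end-to-end DeepSpeed Stage 3 flow it slots into. The ``DeepSpeedStrategy(stage=3)`` configuration, the two-GPU setup, and the small ``torch.nn.Linear`` model standing in for ``GPT3`` are illustrative assumptions, not part of the patch.

.. code-block:: python

    import torch
    from lightning.fabric import Fabric
    from lightning.fabric.strategies import DeepSpeedStrategy

    # Illustrative two-GPU Stage 3 configuration (assumed, not from the patch)
    fabric = Fabric(accelerator="cuda", devices=2, strategy=DeepSpeedStrategy(stage=3))
    fabric.launch()

    # Parameters are created empty, so the full model is never materialized on a single GPU
    with fabric.init_module(empty_init=True):
        model = torch.nn.Linear(128, 128)  # small stand-in for the docs' GPT3 placeholder

    # DeepSpeed needs the model and optimizer wrapped together in a single setup() call
    optimizer = torch.optim.Adam(model.parameters())
    model, optimizer = fabric.setup(model, optimizer)

    # One training step: forward, backward through Fabric, optimizer update
    batch = torch.randn(4, 128, device=fabric.device)
    loss = model(batch).sum()
    fabric.backward(loss)
    optimizer.step()
    optimizer.zero_grad()

The joint ``fabric.setup(model, optimizer)`` call replaces the separate ``setup``/``setup_optimizers`` calls shown for FSDP/TP, reflecting the patch's note that DeepSpeed requires the model and optimizer to be set up jointly.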