Skip to content

Commit 3518f9e

Browse files
authored
Delay DeepSpeed config setup (#19209)
1 parent 91ef190 commit 3518f9e

File tree

4 files changed

+236
-289
lines changed

4 files changed

+236
-289
lines changed

src/lightning/pytorch/strategies/deepspeed.py

Lines changed: 137 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -313,20 +313,6 @@ def __init__(
313313
self.hysteresis = hysteresis
314314
self.min_loss_scale = min_loss_scale
315315

316-
def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Optional[Dict[str, Any]]:
317-
if config is None and self.DEEPSPEED_ENV_VAR in os.environ:
318-
rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable")
319-
config = os.environ[self.DEEPSPEED_ENV_VAR]
320-
if isinstance(config, (str, Path)):
321-
if not os.path.isfile(config):
322-
raise MisconfigurationException(
323-
f"You passed in a path to a DeepSpeed config but the path does not exist: {config}"
324-
)
325-
with open(config) as f:
326-
config = json.load(f)
327-
assert isinstance(config, dict) or config is None
328-
return config
329-
330316
@override
331317
def setup_environment(self) -> None:
332318
if not isinstance(self.accelerator, CUDAAccelerator):
@@ -343,12 +329,10 @@ def setup_distributed(self) -> None:
343329
reset_seed()
344330
self.set_world_ranks()
345331
self._init_deepspeed_distributed()
346-
if not self._config_initialized:
347-
self._format_config()
348-
self._config_initialized = True
349332

350333
@override
351334
def setup(self, trainer: "pl.Trainer") -> None:
335+
self._init_config_if_needed()
352336
assert self.accelerator is not None
353337
self.accelerator.setup(trainer)
354338
# we set the device so that optimizers can be created with distributed comms.
@@ -529,7 +513,7 @@ def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[No
529513
def model_sharded_context(self) -> Generator[None, None, None]:
530514
import deepspeed
531515

532-
assert self._config_initialized
516+
self._init_config_if_needed()
533517
with deepspeed.zero.Init(
534518
enabled=self.zero_stage_3,
535519
remote_device=self.remote_device,
@@ -610,134 +594,6 @@ def handles_gradient_accumulation(self) -> bool:
610594
"""Whether the strategy handles gradient accumulation internally."""
611595
return True
612596

613-
def _format_config(self) -> None:
614-
if self.config is None:
615-
raise MisconfigurationException(
616-
"To use DeepSpeed you must pass in a DeepSpeed config dict, or a path to a JSON config."
617-
" See: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#deepspeed"
618-
)
619-
self._format_batch_size_and_grad_accum_config()
620-
_format_precision_config(
621-
config=self.config,
622-
precision=self.precision_plugin.precision,
623-
loss_scale=self.loss_scale,
624-
loss_scale_window=self.loss_scale_window,
625-
min_loss_scale=self.min_loss_scale,
626-
initial_scale_power=self.initial_scale_power,
627-
hysteresis=self.hysteresis,
628-
)
629-
630-
def _format_batch_size_and_grad_accum_config(self) -> None:
631-
# TODO: Using Fabric, we do not support these variables within the config
632-
assert isinstance(self.config, dict)
633-
if self.lightning_module is None:
634-
return
635-
636-
if "gradient_accumulation_steps" in self.config:
637-
raise MisconfigurationException(
638-
"Do not set `gradient_accumulation_steps` in the DeepSpeed config"
639-
" as this will be set with the `accumulate_grad_batches` argument passed via the Lightning Trainer."
640-
)
641-
self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches
642-
if "train_micro_batch_size_per_gpu" not in self.config:
643-
batch_size = self._auto_select_batch_size()
644-
self.config["train_micro_batch_size_per_gpu"] = batch_size
645-
if "gradient_clipping" not in self.config:
646-
self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val or 0.0
647-
648-
def _auto_select_batch_size(self) -> int:
649-
import deepspeed
650-
651-
# train_micro_batch_size_per_gpu is used for throughput logging purposes
652-
# by default we try to use the batch size of the loader
653-
assert self.lightning_module is not None
654-
batch_size = 1
655-
data_source = self.lightning_module.trainer.fit_loop._data_source
656-
if data_source.is_defined():
657-
try:
658-
train_dataloader = data_source.dataloader()
659-
if hasattr(train_dataloader, "batch_sampler"):
660-
batch_size = train_dataloader.batch_sampler.batch_size
661-
# broad exception on purpose as `source.dataloader()` will fail if the dataloader requires `setup`
662-
# to have been called before
663-
except Exception:
664-
if self.global_rank == 0:
665-
deepspeed.utils.logging.logger.warning(
666-
"Tried to infer the batch size for internal deepspeed logging from the `train_dataloader()`. "
667-
"To ensure DeepSpeed logging remains correct, please manually pass the strategy with the "
668-
"batch size, `Trainer(strategy=DeepSpeedStrategy(logging_batch_size_per_gpu=batch_size))`."
669-
)
670-
return batch_size
671-
672-
def _create_default_config(
673-
self,
674-
zero_optimization: bool,
675-
zero_allow_untested_optimizer: bool,
676-
logging_batch_size_per_gpu: Union[str, int],
677-
partition_activations: bool,
678-
cpu_checkpointing: bool,
679-
contiguous_memory_optimization: bool,
680-
synchronize_checkpoint_boundary: bool,
681-
offload_optimizer: bool,
682-
offload_parameters: bool,
683-
nvme_path: str,
684-
offload_params_device: str,
685-
params_buffer_count: int,
686-
params_buffer_size: int,
687-
max_in_cpu: int,
688-
offload_optimizer_device: str,
689-
optimizer_buffer_count: int,
690-
pin_memory: bool,
691-
block_size: int,
692-
queue_depth: int,
693-
single_submit: bool,
694-
overlap_events: bool,
695-
thread_count: int,
696-
**zero_kwargs: Any,
697-
) -> Dict:
698-
cfg = {
699-
"activation_checkpointing": {
700-
"partition_activations": partition_activations,
701-
"cpu_checkpointing": cpu_checkpointing,
702-
"contiguous_memory_optimization": contiguous_memory_optimization,
703-
"synchronize_checkpoint_boundary": synchronize_checkpoint_boundary,
704-
},
705-
"aio": {
706-
"block_size": block_size,
707-
"queue_depth": queue_depth,
708-
"single_submit": single_submit,
709-
"overlap_events": overlap_events,
710-
"thread_count": thread_count,
711-
},
712-
}
713-
if zero_optimization:
714-
zero_config = zero_kwargs
715-
716-
if offload_optimizer:
717-
zero_config["offload_optimizer"] = {
718-
"device": offload_optimizer_device,
719-
"nvme_path": nvme_path,
720-
"buffer_count": optimizer_buffer_count,
721-
"pin_memory": pin_memory,
722-
}
723-
if offload_parameters:
724-
zero_config["offload_param"] = {
725-
"device": offload_params_device,
726-
"nvme_path": nvme_path,
727-
"buffer_count": params_buffer_count,
728-
"buffer_size": params_buffer_size,
729-
"max_in_cpu": max_in_cpu,
730-
"pin_memory": pin_memory,
731-
}
732-
cfg = {
733-
"zero_allow_untested_optimizer": zero_allow_untested_optimizer,
734-
"zero_optimization": zero_config,
735-
**cfg,
736-
}
737-
if logging_batch_size_per_gpu != "auto":
738-
cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg}
739-
return cfg
740-
741597
@property
742598
def deepspeed_engine(self) -> "deepspeed.DeepSpeedEngine":
743599
return self.model
@@ -915,3 +771,138 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
915771
offload_params_device="nvme",
916772
offload_optimizer_device="nvme",
917773
)
774+
775+
def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Optional[Dict[str, Any]]:
    """Resolve the user-supplied DeepSpeed config into a dictionary.

    Args:
        config: A config dict, a path to a JSON config file, or ``None``. When ``None`` and the
            environment variable named by ``self.DEEPSPEED_ENV_VAR`` is set, the value of that
            variable is used as the path.

    Returns:
        The config dictionary, or ``None`` when no config was given and the environment variable is unset.

    Raises:
        MisconfigurationException: If the path does not point to an existing file, or if the resolved
            config is neither ``None`` nor a dict.

    """
    if config is None and self.DEEPSPEED_ENV_VAR in os.environ:
        rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable")
        config = os.environ[self.DEEPSPEED_ENV_VAR]
    if isinstance(config, (str, Path)):
        if not os.path.isfile(config):
            raise MisconfigurationException(
                f"You passed in a path to a DeepSpeed config but the path does not exist: {config}"
            )
        # JSON text is UTF-8; do not depend on the platform's locale default encoding
        with open(config, encoding="utf-8") as f:
            config = json.load(f)
    # explicit check instead of `assert`, which is stripped when Python runs with `-O`
    if config is not None and not isinstance(config, dict):
        raise MisconfigurationException(
            "The DeepSpeed config must be a dict or a path to a JSON file containing an object,"
            f" but got: {type(config).__name__}"
        )
    return config
788+
789+
def _init_config_if_needed(self) -> None:
    """Format the DeepSpeed config on the first call; every later call is a no-op."""
    if self._config_initialized:
        return
    self._format_config()
    # only flip the flag once formatting succeeded, so a failed attempt is retried
    self._config_initialized = True
793+
794+
def _format_config(self) -> None:
    """Fill the DeepSpeed config with batch-size, gradient-accumulation and precision settings.

    Raises:
        MisconfigurationException: If no config is available on the strategy.

    """
    if self.config is None:
        message = (
            "To use DeepSpeed you must pass in a DeepSpeed config dict, or a path to a JSON config."
            " See: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#deepspeed"
        )
        raise MisconfigurationException(message)

    self._format_batch_size_and_grad_accum_config()

    # the precision entries are written into ``self.config`` in place by the helper
    precision_settings = {
        "precision": self.precision_plugin.precision,
        "loss_scale": self.loss_scale,
        "loss_scale_window": self.loss_scale_window,
        "min_loss_scale": self.min_loss_scale,
        "initial_scale_power": self.initial_scale_power,
        "hysteresis": self.hysteresis,
    }
    _format_precision_config(config=self.config, **precision_settings)
810+
811+
def _create_default_config(
    self,
    zero_optimization: bool,
    zero_allow_untested_optimizer: bool,
    logging_batch_size_per_gpu: Union[str, int],
    partition_activations: bool,
    cpu_checkpointing: bool,
    contiguous_memory_optimization: bool,
    synchronize_checkpoint_boundary: bool,
    offload_optimizer: bool,
    offload_parameters: bool,
    nvme_path: str,
    offload_params_device: str,
    params_buffer_count: int,
    params_buffer_size: int,
    max_in_cpu: int,
    offload_optimizer_device: str,
    optimizer_buffer_count: int,
    pin_memory: bool,
    block_size: int,
    queue_depth: int,
    single_submit: bool,
    overlap_events: bool,
    thread_count: int,
    **zero_kwargs: Any,
) -> Dict:
    """Assemble a DeepSpeed config dictionary from the individual strategy arguments.

    The result always contains the ``activation_checkpointing`` and ``aio`` sections. When
    ``zero_optimization`` is true, ``zero_allow_untested_optimizer`` and a ``zero_optimization``
    section (built from ``zero_kwargs`` plus the optional offload sub-sections) are added in front
    of them. When ``logging_batch_size_per_gpu`` is not ``"auto"``, it is stored first under
    ``train_micro_batch_size_per_gpu``.

    Returns:
        The assembled config dictionary.

    """
    config: Dict[str, Any] = {}

    # entries are inserted in the order they should appear in the final dictionary
    if logging_batch_size_per_gpu != "auto":
        config["train_micro_batch_size_per_gpu"] = logging_batch_size_per_gpu

    if zero_optimization:
        if offload_optimizer:
            zero_kwargs["offload_optimizer"] = {
                "device": offload_optimizer_device,
                "nvme_path": nvme_path,
                "buffer_count": optimizer_buffer_count,
                "pin_memory": pin_memory,
            }
        if offload_parameters:
            zero_kwargs["offload_param"] = {
                "device": offload_params_device,
                "nvme_path": nvme_path,
                "buffer_count": params_buffer_count,
                "buffer_size": params_buffer_size,
                "max_in_cpu": max_in_cpu,
                "pin_memory": pin_memory,
            }
        config["zero_allow_untested_optimizer"] = zero_allow_untested_optimizer
        config["zero_optimization"] = zero_kwargs

    config["activation_checkpointing"] = {
        "partition_activations": partition_activations,
        "cpu_checkpointing": cpu_checkpointing,
        "contiguous_memory_optimization": contiguous_memory_optimization,
        "synchronize_checkpoint_boundary": synchronize_checkpoint_boundary,
    }
    config["aio"] = {
        "block_size": block_size,
        "queue_depth": queue_depth,
        "single_submit": single_submit,
        "overlap_events": overlap_events,
        "thread_count": thread_count,
    }
    return config
879+
880+
def _format_batch_size_and_grad_accum_config(self) -> None:
    """Copy the Trainer's accumulation, batch-size and clipping settings into the DeepSpeed config.

    Does nothing when no LightningModule is attached to the strategy.

    Raises:
        MisconfigurationException: If the config already contains ``gradient_accumulation_steps``.

    """
    # TODO: Using Fabric, we do not support these variables within the config
    assert isinstance(self.config, dict)
    module = self.lightning_module
    if module is None:
        return

    ds_config = self.config
    if "gradient_accumulation_steps" in ds_config:
        raise MisconfigurationException(
            "Do not set `gradient_accumulation_steps` in the DeepSpeed config"
            " as this will be set with the `accumulate_grad_batches` argument passed via the Lightning Trainer."
        )
    ds_config["gradient_accumulation_steps"] = module.trainer.accumulate_grad_batches

    # values the user already put in the config take precedence over the inferred ones
    if "train_micro_batch_size_per_gpu" not in ds_config:
        ds_config["train_micro_batch_size_per_gpu"] = self._auto_select_batch_size()
    if "gradient_clipping" not in ds_config:
        ds_config["gradient_clipping"] = module.trainer.gradient_clip_val or 0.0
897+
898+
def _auto_select_batch_size(self) -> int:
    """Infer the per-device batch size from the training dataloader.

    ``train_micro_batch_size_per_gpu`` is used for throughput logging purposes, so by default the
    batch size of the training loader is used.

    Returns:
        The ``batch_size`` of the training dataloader's ``batch_sampler`` when one is available,
        otherwise ``1``.

    """
    assert self.lightning_module is not None
    batch_size = 1
    data_source = self.lightning_module.trainer.fit_loop._data_source
    if data_source.is_defined():
        train_dataloader = data_source.dataloader()
        # `batch_sampler` is None when automatic batching is disabled, and a custom batch sampler
        # need not expose `batch_size`; keep the default of 1 in both cases rather than failing
        # with an AttributeError or returning None
        batch_sampler = getattr(train_dataloader, "batch_sampler", None)
        sampler_batch_size = getattr(batch_sampler, "batch_size", None)
        if sampler_batch_size is not None:
            batch_size = sampler_batch_size
    return batch_size

src/lightning/pytorch/trainer/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -946,7 +946,7 @@ def _run(
946946
self.strategy.setup_environment()
947947
self.__setup_profiler()
948948

949-
call._call_setup_hook(self) # allow user to setup lightning_module in accelerator environment
949+
call._call_setup_hook(self) # allow user to set up LightningModule in accelerator environment
950950
log.debug(f"{self.__class__.__name__}: configuring model")
951951
call._call_configure_model(self)
952952

tests/tests_pytorch/models/test_hooks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,10 +477,10 @@ def training_step(self, batch, batch_idx):
477477
expected = [
478478
{"name": "configure_callbacks"},
479479
{"name": "prepare_data"},
480-
# DeepSpeed needs the batch size to figure out throughput logging
481-
*([{"name": "train_dataloader"}] if using_deepspeed else []),
482480
{"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "fit"}},
483481
{"name": "setup", "kwargs": {"stage": "fit"}},
482+
# DeepSpeed needs the batch size to figure out throughput logging
483+
*([{"name": "train_dataloader"}] if using_deepspeed else []),
484484
{"name": "configure_model"},
485485
{"name": "configure_optimizers"},
486486
{"name": "Callback.on_fit_start", "args": (trainer, model)},

0 commit comments

Comments
 (0)