From decb98a96712b0f4b21dae7068716e58e86ce098 Mon Sep 17 00:00:00 2001
From: Zhiyuan Li
Date: Sat, 19 Oct 2024 02:00:03 +1100
Subject: [PATCH 1/8] enhance 3rd-party devices in mixed precision

---
 .../source-pytorch/extensions/accelerator.rst | 39 ++++++++++++++++++-
 .../fabric/accelerators/accelerator.py        |  5 +++
 src/lightning/fabric/accelerators/cpu.py      |  5 +++
 src/lightning/fabric/accelerators/cuda.py     |  5 +++
 src/lightning/fabric/accelerators/mps.py      |  5 +++
 src/lightning/fabric/accelerators/xla.py      |  5 +++
 src/lightning/fabric/connector.py             |  9 ++++-
 src/lightning/fabric/plugins/precision/amp.py |  9 ++++-
 .../fabric/plugins/precision/fsdp.py          |  9 ++++-
 src/lightning/fabric/strategies/ddp.py        |  8 +++-
 src/lightning/fabric/strategies/deepspeed.py  | 16 ++++++--
 src/lightning/fabric/strategies/strategy.py   |  5 ++-
 .../pytorch/accelerators/accelerator.py       |  5 +++
 src/lightning/pytorch/accelerators/cpu.py     |  5 +++
 src/lightning/pytorch/accelerators/cuda.py    |  5 +++
 src/lightning/pytorch/accelerators/mps.py     |  5 +++
 .../pytorch/plugins/precision/amp.py          |  9 ++++-
 .../pytorch/plugins/precision/fsdp.py         |  9 ++++-
 src/lightning/pytorch/strategies/ddp.py       |  8 +++-
 src/lightning/pytorch/strategies/deepspeed.py | 10 +++--
 src/lightning/pytorch/strategies/strategy.py  |  5 ++-
 .../connectors/accelerator_connector.py       | 24 ++++++++++--
 .../accelerators/test_registry.py             |  4 ++
 tests/tests_fabric/test_connector.py          |  4 ++
 .../connectors/test_accelerator_connector.py  |  4 ++
 25 files changed, 193 insertions(+), 24 deletions(-)

diff --git a/docs/source-pytorch/extensions/accelerator.rst b/docs/source-pytorch/extensions/accelerator.rst
index 93dc467b02921..dcedde8c6905c 100644
--- a/docs/source-pytorch/extensions/accelerator.rst
+++ b/docs/source-pytorch/extensions/accelerator.rst
@@ -36,29 +36,57 @@ Let's pretend we want to integrate the fictional XPU accelerator and we have acc
 
 .. code-block:: python
 
+    import torch
     import xpulib
+    from functools import lru_cache
+    from typing import Any, Dict, Union
+    from lightning.pytorch.accelerators.accelerator import Accelerator
+
+    from typing_extensions import override
+

     class XPUAccelerator(Accelerator):
         """Support for a hypothetical XPU, optimized for large-scale machine learning."""

+        @override
+        def setup_device(self, device: torch.device) -> None:
+            """
+            Raises:
+                ValueError:
+                    If the selected device is not of type hypothetical XPU.
+ """ + if device.type != "xpu": + raise ValueError(f"Device should be of type 'xpu', got '{device.type}' instead.") + if device.index is None: + device = torch.device("xpu", 0) + xpulib.set_device(device.index) + + @override + def teardown(self) -> None: + xpulib.empty_cache() + @staticmethod + @override def parse_devices(devices: Any) -> Any: # Put parsing logic here how devices can be passed into the Trainer # via the `devices` argument return devices @staticmethod + @override def get_parallel_devices(devices: Any) -> Any: # Here, convert the device indices to actual device objects return [torch.device("xpu", idx) for idx in devices] @staticmethod + @override def auto_device_count() -> int: # Return a value for auto-device selection when `Trainer(devices="auto")` return xpulib.available_devices() @staticmethod + @override def is_available() -> bool: return xpulib.is_available() @@ -66,15 +94,21 @@ Let's pretend we want to integrate the fictional XPU accelerator and we have acc # Return optional device statistics for loggers return {} + @staticmethod + @override + def get_device() -> str: + return "xpu" + Finally, add the XPUAccelerator to the Trainer: .. code-block:: python from lightning.pytorch import Trainer - + from lightning.pytorch.strategies import DDPStrategy accelerator = XPUAccelerator() - trainer = Trainer(accelerator=accelerator, devices=2) + strategy = DDPStrategy(parallel_devices=accelerator.get_parallel_devices(2)) + trainer = Trainer(accelerator=accelerator, strategy=strategy, devices=2) :doc:`Learn more about Strategies <../extensions/strategy>` and how they interact with the Accelerator. @@ -93,6 +127,7 @@ If you wish to switch to a custom accelerator from the CLI without code changes, ... @classmethod + @override def register_accelerators(cls, accelerator_registry): accelerator_registry.register( "xpu", diff --git a/src/lightning/fabric/accelerators/accelerator.py b/src/lightning/fabric/accelerators/accelerator.py index 3a8aa85ad041d..84ef97e514bbc 100644 --- a/src/lightning/fabric/accelerators/accelerator.py +++ b/src/lightning/fabric/accelerators/accelerator.py @@ -46,6 +46,11 @@ def parse_devices(devices: Any) -> Any: def get_parallel_devices(devices: Any) -> Any: """Gets parallel devices for the Accelerator.""" + @staticmethod + @abstractmethod + def get_device() -> Any: + """Get the device for the current Accelerator.""" + @staticmethod @abstractmethod def auto_device_count() -> int: diff --git a/src/lightning/fabric/accelerators/cpu.py b/src/lightning/fabric/accelerators/cpu.py index 1bcec1b2ac278..e019ea100ee8b 100644 --- a/src/lightning/fabric/accelerators/cpu.py +++ b/src/lightning/fabric/accelerators/cpu.py @@ -50,6 +50,11 @@ def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.devi devices = _parse_cpu_cores(devices) return [torch.device("cpu")] * devices + @staticmethod + @override + def get_device() -> str: + return "cpu" + @staticmethod @override def auto_device_count() -> int: diff --git a/src/lightning/fabric/accelerators/cuda.py b/src/lightning/fabric/accelerators/cuda.py index 4afc9be723fc2..420f645dc9cb6 100644 --- a/src/lightning/fabric/accelerators/cuda.py +++ b/src/lightning/fabric/accelerators/cuda.py @@ -55,6 +55,11 @@ def get_parallel_devices(devices: List[int]) -> List[torch.device]: """Gets parallel devices for the Accelerator.""" return [torch.device("cuda", i) for i in devices] + @staticmethod + @override + def get_device() -> str: + return "cuda" + @staticmethod @override def auto_device_count() -> int: diff --git 
a/src/lightning/fabric/accelerators/mps.py b/src/lightning/fabric/accelerators/mps.py index 75497169cda0f..29beb97fc9c9a 100644 --- a/src/lightning/fabric/accelerators/mps.py +++ b/src/lightning/fabric/accelerators/mps.py @@ -60,6 +60,11 @@ def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.devi assert parsed_devices is not None return [torch.device("mps", i) for i in range(len(parsed_devices))] + @staticmethod + @override + def get_device() -> str: + return "mps" + @staticmethod @override def auto_device_count() -> int: diff --git a/src/lightning/fabric/accelerators/xla.py b/src/lightning/fabric/accelerators/xla.py index 38d7380dc7905..5055a29398f60 100644 --- a/src/lightning/fabric/accelerators/xla.py +++ b/src/lightning/fabric/accelerators/xla.py @@ -64,6 +64,11 @@ def get_parallel_devices(devices: Union[int, List[int]]) -> List[torch.device]: # accelerator connector init). However, there doesn't seem to be a problem with instantiating `torch.device`. # it will be replaced with `xla_device` (also a torch.device`, but with extra logic) in the strategy + @staticmethod + @override + def get_device() -> str: + return "xla" + @staticmethod @override # XLA's multiprocessing will pop the TPU_NUM_DEVICES key, so we need to cache it diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index 9fb66255830c6..3cf9ed681fe6b 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -141,6 +141,8 @@ def __init__( self._accelerator_flag = self._choose_auto_accelerator() elif self._accelerator_flag == "gpu": self._accelerator_flag = self._choose_gpu_accelerator_backend() + elif isinstance(self._accelerator_flag, Accelerator): + pass # for 3rd party accelerator, just do nothing self._set_parallel_devices_and_init_accelerator() @@ -461,7 +463,10 @@ def _check_and_init_precision(self) -> Precision: if isinstance(self.strategy, DeepSpeedStrategy): return DeepSpeedPrecision(self._precision_input) # type: ignore if isinstance(self.strategy, FSDPStrategy): - return FSDPPrecision(precision=self._precision_input) # type: ignore[arg-type] + return FSDPPrecision( + precision=self._precision_input, # type: ignore[arg-type] + device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None, + ) mp_precision_supported = ("32-true", "bf16-mixed", "bf16-true", "16-true") if isinstance(self.strategy, ModelParallelStrategy) and self._precision_input not in mp_precision_supported: raise ValueError( @@ -493,6 +498,8 @@ def _check_and_init_precision(self) -> Precision: else "Using bfloat16 Automatic Mixed Precision (AMP)" ) device = "cpu" if self._accelerator_flag == "cpu" else "cuda" + if isinstance(self._accelerator_flag, Accelerator): + device = self._accelerator_flag.get_device() return MixedPrecision(precision=self._precision_input, device=device) # type: ignore[arg-type] raise RuntimeError("No precision set") diff --git a/src/lightning/fabric/plugins/precision/amp.py b/src/lightning/fabric/plugins/precision/amp.py index c624e821af28c..1e2f54e2c7270 100644 --- a/src/lightning/fabric/plugins/precision/amp.py +++ b/src/lightning/fabric/plugins/precision/amp.py @@ -50,7 +50,14 @@ def __init__( self.precision = precision if scaler is None and self.precision == "16-mixed": - scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler() + scaler = ( + torch.amp.GradScaler(device=device) + if _TORCH_GREATER_EQUAL_2_4 + else getattr( + torch, + "cuda" if not 
isinstance(device, str) or device.split(":")[0] == "cpu" else device.split(":")[0], + ).amp.GradScaler() + ) if scaler is not None and self.precision == "bf16-mixed": raise ValueError(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.") self.device = device diff --git a/src/lightning/fabric/plugins/precision/fsdp.py b/src/lightning/fabric/plugins/precision/fsdp.py index 179fc21cdd90d..43570373a39b1 100644 --- a/src/lightning/fabric/plugins/precision/fsdp.py +++ b/src/lightning/fabric/plugins/precision/fsdp.py @@ -48,13 +48,16 @@ class FSDPPrecision(Precision): """ - def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None) -> None: + def __init__( + self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None, device: Optional[str] = None + ) -> None: supported_precision = get_args(_PRECISION_INPUT) if precision not in supported_precision: raise ValueError( f"`precision={precision!r})` is not supported in FSDP." f" `precision` must be one of: {supported_precision}." ) + self.device = device if device is not None else "cuda" from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler @@ -110,7 +113,9 @@ def module_init_context(self) -> ContextManager: @override def forward_context(self) -> ContextManager: if "mixed" in self.precision: - return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)) + return torch.autocast( + self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16) + ) return self.tensor_init_context() @override diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py index c38780655ce6e..e7456fd6a8ca5 100644 --- a/src/lightning/fabric/strategies/ddp.py +++ b/src/lightning/fabric/strategies/ddp.py @@ -124,7 +124,13 @@ def setup_module(self, module: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" device_ids = self._determine_ddp_device_ids() # https://pytorch.org/docs/stable/notes/cuda.html#id5 - ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext() + ctx = ( + getattr(torch, f"{self.root_device.type.split(':')[0]}").stream( + getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream() + ) + if device_ids is not None + else nullcontext() + ) with ctx: return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs) diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index e71b8e2db3d58..e74eb39bd79a5 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -506,7 +506,11 @@ def load_checkpoint( optimzer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values()) - torch.cuda.empty_cache() + if isinstance(self.accelerator, Accelerator) and self.accelerator.get_device() != "cpu": + getattr(torch, self.root_device.type.split(":")[0]).empty_cache() + else: + torch.cuda.empty_cache() + _, client_state = engine.load_checkpoint( path, tag="checkpoint", @@ -616,10 +620,14 @@ def _initialize_engine( @override def setup_environment(self) -> None: - if not isinstance(self.accelerator, CUDAAccelerator): + from deepspeed.runtime.utils import get_accelerator + + if ( + not isinstance(self.accelerator, CUDAAccelerator) + ) and self.accelerator.get_device() != get_accelerator().device_name(): # type: 
ignore[union-attr] raise RuntimeError( - f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`" - " is used." + f"The DeepSpeed strategy is only supported on {get_accelerator().device_name().upper()} GPUs, " + f"but `{self.accelerator.__class__.__name__}` is used." ) super().setup_environment() diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py index 6bfed6a270b68..21a5aaffa9900 100644 --- a/src/lightning/fabric/strategies/strategy.py +++ b/src/lightning/fabric/strategies/strategy.py @@ -325,7 +325,10 @@ def load_checkpoint( given, the full checkpoint will be returned. """ - torch.cuda.empty_cache() + if isinstance(self.accelerator, Accelerator) and self.accelerator.get_device() != "cpu": + getattr(torch, self.root_device.type.split(":")[0]).empty_cache() + else: + torch.cuda.empty_cache() checkpoint = self.checkpoint_io.load_checkpoint(path) if not state: return checkpoint diff --git a/src/lightning/pytorch/accelerators/accelerator.py b/src/lightning/pytorch/accelerators/accelerator.py index 0490c2d86431c..96a3941af97f3 100644 --- a/src/lightning/pytorch/accelerators/accelerator.py +++ b/src/lightning/pytorch/accelerators/accelerator.py @@ -45,3 +45,8 @@ def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: """ raise NotImplementedError + + @staticmethod + def get_device() -> str: + """Get the device for the current process.""" + raise NotImplementedError diff --git a/src/lightning/pytorch/accelerators/cpu.py b/src/lightning/pytorch/accelerators/cpu.py index 735312b363d11..ab6304053f314 100644 --- a/src/lightning/pytorch/accelerators/cpu.py +++ b/src/lightning/pytorch/accelerators/cpu.py @@ -80,6 +80,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No description=cls.__name__, ) + @staticmethod + @override + def get_device() -> str: + return "cpu" + # CPU device metrics _CPU_VM_PERCENT = "cpu_vm_percent" diff --git a/src/lightning/pytorch/accelerators/cuda.py b/src/lightning/pytorch/accelerators/cuda.py index 6df3bc6b468ee..cfb85cb2c2990 100644 --- a/src/lightning/pytorch/accelerators/cuda.py +++ b/src/lightning/pytorch/accelerators/cuda.py @@ -113,6 +113,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No description=cls.__name__, ) + @staticmethod + @override + def get_device() -> str: + return "cuda" + def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]: # pragma: no-cover """Get GPU stats including memory, fan speed, and temperature from nvidia-smi. 
diff --git a/src/lightning/pytorch/accelerators/mps.py b/src/lightning/pytorch/accelerators/mps.py index 6efe6292de624..d8bda9dae8087 100644 --- a/src/lightning/pytorch/accelerators/mps.py +++ b/src/lightning/pytorch/accelerators/mps.py @@ -87,6 +87,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No description=cls.__name__, ) + @staticmethod + @override + def get_device() -> str: + return "mps" + # device metrics _VM_PERCENT = "M1_vm_percent" diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py index e63ccd6912b63..eb17c33a902de 100644 --- a/src/lightning/pytorch/plugins/precision/amp.py +++ b/src/lightning/pytorch/plugins/precision/amp.py @@ -50,7 +50,14 @@ def __init__( self.precision = precision if scaler is None and self.precision == "16-mixed": - scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler() + scaler = ( + torch.amp.GradScaler(device=device) + if _TORCH_GREATER_EQUAL_2_4 + else getattr( + torch, + "cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu" else device.split(":")[0], + ).amp.GradScaler() + ) if scaler is not None and self.precision == "bf16-mixed": raise MisconfigurationException(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.") self.device = device diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index e6c684967ed40..280defe04ff44 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -47,13 +47,16 @@ class FSDPPrecision(Precision): """ - def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None) -> None: + def __init__( + self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None, device: Optional[str] = None + ) -> None: supported_precision = get_args(_PRECISION_INPUT) if precision not in supported_precision: raise ValueError( f"`precision={precision!r})` is not supported in FSDP." f" `precision` must be one of: {supported_precision}." 
) + self.device = device if device is not None else "cuda" from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler @@ -119,7 +122,9 @@ def module_init_context(self) -> ContextManager: @override def forward_context(self) -> ContextManager: if "mixed" in self.precision: - return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)) + return torch.autocast( + self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16) + ) return _DtypeContextManager(self._desired_input_dtype) @override diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index 9031b6ee177f3..c16310cd65245 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ b/src/lightning/pytorch/strategies/ddp.py @@ -190,7 +190,13 @@ def _setup_model(self, model: Module) -> DistributedDataParallel: device_ids = self.determine_ddp_device_ids() log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}") # https://pytorch.org/docs/stable/notes/cuda.html#id5 - ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext() + ctx = ( + getattr(torch, f"{self.root_device.type.split(':')[0]}").stream( + getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream() + ) + if device_ids is not None + else nullcontext() + ) with ctx: return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs) diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index 1eaa5bab75fbe..8bbd3dd0191c7 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -316,10 +316,14 @@ def __init__( @override def setup_environment(self) -> None: - if not isinstance(self.accelerator, CUDAAccelerator): + from deepspeed.runtime.utils import get_accelerator + + if ( + not isinstance(self.accelerator, CUDAAccelerator) + ) and self.accelerator.get_device() != get_accelerator().device_name(): # type: ignore[union-attr] raise RuntimeError( - f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`" - " is used." + f"The DeepSpeed strategy is only supported on {get_accelerator().device_name().upper()} GPUs, " + f"but `{self.accelerator.__class__.__name__}` is used." 
) super().setup_environment() diff --git a/src/lightning/pytorch/strategies/strategy.py b/src/lightning/pytorch/strategies/strategy.py index 314007f497f59..e2c6642f6b932 100644 --- a/src/lightning/pytorch/strategies/strategy.py +++ b/src/lightning/pytorch/strategies/strategy.py @@ -363,7 +363,10 @@ def lightning_module(self) -> Optional["pl.LightningModule"]: return self._lightning_module def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: - torch.cuda.empty_cache() + if isinstance(self.accelerator, pl.accelerators.Accelerator) and self.accelerator.get_device() != "cpu": + getattr(torch, self.root_device.type.split(":")[0]).empty_cache() + else: + torch.cuda.empty_cache() return self.checkpoint_io.load_checkpoint(checkpoint_path) def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None: diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index 06f3ee366bcaa..c4d7526b85691 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -141,6 +141,8 @@ def __init__( self._accelerator_flag = self._choose_auto_accelerator() elif self._accelerator_flag == "gpu": self._accelerator_flag = self._choose_gpu_accelerator_backend() + elif isinstance(self._accelerator_flag, Accelerator): + pass # for 3rd party accelerator, just do nothing self._check_device_config_and_set_final_flags(devices=devices, num_nodes=num_nodes) self._set_parallel_devices_and_init_accelerator() @@ -301,13 +303,15 @@ def _check_config_and_set_final_flags( f" but accelerator set to {self._accelerator_flag}, please choose one device type" ) self._accelerator_flag = "cpu" - if self._strategy_flag.parallel_devices[0].type == "cuda": + elif self._strategy_flag.parallel_devices[0].type == "cuda": if self._accelerator_flag and self._accelerator_flag not in ("auto", "cuda", "gpu"): raise MisconfigurationException( f"GPU parallel_devices set through {self._strategy_flag.__class__.__name__} class," f" but accelerator set to {self._accelerator_flag}, please choose one device type" ) self._accelerator_flag = "cuda" + else: + pass # 3rd party accelerator self._parallel_devices = self._strategy_flag.parallel_devices def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str, int], num_nodes: int) -> None: @@ -457,12 +461,19 @@ def _check_strategy_and_fallback(self) -> None: strategy_flag = "" if isinstance(self._strategy_flag, Strategy) else self._strategy_flag if ( - strategy_flag in FSDPStrategy.get_registered_strategies() or type(self._strategy_flag) is FSDPStrategy - ) and self._accelerator_flag not in ("cuda", "gpu"): + (strategy_flag in FSDPStrategy.get_registered_strategies() or type(self._strategy_flag) is FSDPStrategy) + and self._accelerator_flag not in ("cuda", "gpu") + and isinstance(self._accelerator_flag, str) + ): raise ValueError( f"The strategy `{FSDPStrategy.strategy_name}` requires a GPU accelerator, but got:" f" {self._accelerator_flag}" ) + if isinstance(self._accelerator_flag, Accelerator): + Warning( + f"Using a custom accelerator `{self._accelerator_flag.__class__.__name__}`." + f" Please ensure it is compatible with the selected strategy `{strategy_flag}`." 
+ ) if strategy_flag in _DDP_FORK_ALIASES and "fork" not in torch.multiprocessing.get_all_start_methods(): raise ValueError( f"You selected `Trainer(strategy='{strategy_flag}')` but process forking is not supported on this" @@ -496,7 +507,10 @@ def _check_and_init_precision(self) -> Precision: if isinstance(self.strategy, DeepSpeedStrategy): return DeepSpeedPrecision(self._precision_flag) # type: ignore[arg-type] if isinstance(self.strategy, FSDPStrategy): - return FSDPPrecision(self._precision_flag) # type: ignore[arg-type] + return FSDPPrecision( + precision=self._precision_flag, # type: ignore[arg-type] + device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None, + ) if self._precision_flag in ("16-true", "bf16-true"): return HalfPrecision(self._precision_flag) # type: ignore if self._precision_flag == "32-true": @@ -520,6 +534,8 @@ def _check_and_init_precision(self) -> Precision: f"Using {'16bit' if self._precision_flag == '16-mixed' else 'bfloat16'} Automatic Mixed Precision (AMP)" ) device = "cpu" if self._accelerator_flag == "cpu" else "cuda" + if isinstance(self._accelerator_flag, Accelerator): + device = self._accelerator_flag.get_device() return MixedPrecision(self._precision_flag, device) # type: ignore[arg-type] raise RuntimeError("No precision set") diff --git a/tests/tests_fabric/accelerators/test_registry.py b/tests/tests_fabric/accelerators/test_registry.py index e8f39b6e83406..2540bde18ce7d 100644 --- a/tests/tests_fabric/accelerators/test_registry.py +++ b/tests/tests_fabric/accelerators/test_registry.py @@ -44,6 +44,10 @@ def parse_devices(devices): def get_parallel_devices(devices): return ["foo"] * devices + @staticmethod + def get_device(): + return "foo" + @staticmethod def auto_device_count(): return 3 diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py index 08d6dbb45ed91..22a998962141b 100644 --- a/tests/tests_fabric/test_connector.py +++ b/tests/tests_fabric/test_connector.py @@ -179,6 +179,10 @@ def parse_devices(devices): def get_parallel_devices(devices): return [torch.device("cpu")] * devices + @staticmethod + def get_device() -> str: + return "cpu" + @staticmethod def auto_device_count() -> int: return 1 diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py index 65c5777e28fed..621fd9106019b 100644 --- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py @@ -192,6 +192,10 @@ def parse_devices(devices): def get_parallel_devices(devices): return [torch.device("cpu")] * devices + @staticmethod + def get_device() -> str: + return "cpu" + @staticmethod def auto_device_count() -> int: return 1 From d212131c8e4fab8d63c5c1adef5cdc9319c48a10 Mon Sep 17 00:00:00 2001 From: Zhiyuan Li Date: Fri, 20 Dec 2024 00:59:50 +0000 Subject: [PATCH 2/8] update codes to device_type --- docs/source-pytorch/extensions/accelerator.rst | 2 +- src/lightning/fabric/accelerators/accelerator.py | 2 +- src/lightning/fabric/accelerators/cpu.py | 2 +- src/lightning/fabric/accelerators/cuda.py | 2 +- src/lightning/fabric/accelerators/mps.py | 2 +- src/lightning/fabric/accelerators/xla.py | 2 +- src/lightning/fabric/connector.py | 4 ++-- src/lightning/fabric/plugins/precision/fsdp.py | 6 +++--- src/lightning/fabric/strategies/deepspeed.py | 6 +++--- src/lightning/fabric/strategies/strategy.py | 2 +- 
src/lightning/pytorch/accelerators/accelerator.py | 2 +- src/lightning/pytorch/accelerators/cpu.py | 2 +- src/lightning/pytorch/accelerators/cuda.py | 2 +- src/lightning/pytorch/accelerators/mps.py | 2 +- src/lightning/pytorch/plugins/precision/fsdp.py | 6 +++--- src/lightning/pytorch/strategies/deepspeed.py | 2 +- src/lightning/pytorch/strategies/strategy.py | 2 +- .../pytorch/trainer/connectors/accelerator_connector.py | 4 ++-- tests/tests_fabric/accelerators/test_registry.py | 2 +- tests/tests_fabric/test_connector.py | 2 +- .../trainer/connectors/test_accelerator_connector.py | 2 +- 21 files changed, 29 insertions(+), 29 deletions(-) diff --git a/docs/source-pytorch/extensions/accelerator.rst b/docs/source-pytorch/extensions/accelerator.rst index dcedde8c6905c..4ea3b639600a9 100644 --- a/docs/source-pytorch/extensions/accelerator.rst +++ b/docs/source-pytorch/extensions/accelerator.rst @@ -96,7 +96,7 @@ Let's pretend we want to integrate the fictional XPU accelerator and we have acc @staticmethod @override - def get_device() -> str: + def get_device_type() -> str: return "xpu" diff --git a/src/lightning/fabric/accelerators/accelerator.py b/src/lightning/fabric/accelerators/accelerator.py index 84ef97e514bbc..bdde00d212edd 100644 --- a/src/lightning/fabric/accelerators/accelerator.py +++ b/src/lightning/fabric/accelerators/accelerator.py @@ -48,7 +48,7 @@ def get_parallel_devices(devices: Any) -> Any: @staticmethod @abstractmethod - def get_device() -> Any: + def get_device_type() -> Any: """Get the device for the current Accelerator.""" @staticmethod diff --git a/src/lightning/fabric/accelerators/cpu.py b/src/lightning/fabric/accelerators/cpu.py index e140dfc2a6307..ba9ff07c85d09 100644 --- a/src/lightning/fabric/accelerators/cpu.py +++ b/src/lightning/fabric/accelerators/cpu.py @@ -52,7 +52,7 @@ def get_parallel_devices(devices: Union[int, str]) -> list[torch.device]: @staticmethod @override - def get_device() -> str: + def get_device_type() -> str: return "cpu" @staticmethod diff --git a/src/lightning/fabric/accelerators/cuda.py b/src/lightning/fabric/accelerators/cuda.py index 83d8088395fb7..d09d0a3fc0097 100644 --- a/src/lightning/fabric/accelerators/cuda.py +++ b/src/lightning/fabric/accelerators/cuda.py @@ -57,7 +57,7 @@ def get_parallel_devices(devices: list[int]) -> list[torch.device]: @staticmethod @override - def get_device() -> str: + def get_device_type() -> str: return "cuda" @staticmethod diff --git a/src/lightning/fabric/accelerators/mps.py b/src/lightning/fabric/accelerators/mps.py index 7dec50683902e..f8a4b68543dee 100644 --- a/src/lightning/fabric/accelerators/mps.py +++ b/src/lightning/fabric/accelerators/mps.py @@ -62,7 +62,7 @@ def get_parallel_devices(devices: Union[int, str, list[int]]) -> list[torch.devi @staticmethod @override - def get_device() -> str: + def get_device_type() -> str: return "mps" @staticmethod diff --git a/src/lightning/fabric/accelerators/xla.py b/src/lightning/fabric/accelerators/xla.py index 71a65ddebd1fd..6a74e207edaa3 100644 --- a/src/lightning/fabric/accelerators/xla.py +++ b/src/lightning/fabric/accelerators/xla.py @@ -66,7 +66,7 @@ def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]: @staticmethod @override - def get_device() -> str: + def get_device_type() -> str: return "xla" @staticmethod diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index 33d421241d8da..72dbb4854484c 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -466,7 +466,7 
@@ def _check_and_init_precision(self) -> Precision: if isinstance(self.strategy, FSDPStrategy): return FSDPPrecision( precision=self._precision_input, # type: ignore[arg-type] - device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None, + device_type=self._accelerator_flag.get_device_type() if isinstance(self._accelerator_flag, Accelerator) else None, ) mp_precision_supported = ("32-true", "bf16-mixed", "bf16-true", "16-true") if isinstance(self.strategy, ModelParallelStrategy) and self._precision_input not in mp_precision_supported: @@ -500,7 +500,7 @@ def _check_and_init_precision(self) -> Precision: ) device = "cpu" if self._accelerator_flag == "cpu" else "cuda" if isinstance(self._accelerator_flag, Accelerator): - device = self._accelerator_flag.get_device() + device = self._accelerator_flag.get_device_type() return MixedPrecision(precision=self._precision_input, device=device) # type: ignore[arg-type] raise RuntimeError("No precision set") diff --git a/src/lightning/fabric/plugins/precision/fsdp.py b/src/lightning/fabric/plugins/precision/fsdp.py index 5253740d63b2a..8fe53aaa9bb7f 100644 --- a/src/lightning/fabric/plugins/precision/fsdp.py +++ b/src/lightning/fabric/plugins/precision/fsdp.py @@ -50,7 +50,7 @@ class FSDPPrecision(Precision): """ def __init__( - self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None, device: Optional[str] = None + self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None, device_type: Optional[str] = None ) -> None: supported_precision = get_args(_PRECISION_INPUT) if precision not in supported_precision: @@ -58,7 +58,7 @@ def __init__( f"`precision={precision!r})` is not supported in FSDP." f" `precision` must be one of: {supported_precision}." 
) - self.device = device if device is not None else "cuda" + self.device_type = device_type if device_type is not None else "cuda" from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler @@ -121,7 +121,7 @@ def module_init_context(self) -> AbstractContextManager: def forward_context(self) -> AbstractContextManager: if "mixed" in self.precision: return torch.autocast( - self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16) + self.device_type, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16) ) return self.tensor_init_context() diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index 341eedae525a2..bf659a4c44a6a 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -511,8 +511,8 @@ def load_checkpoint( optimzer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values()) - if isinstance(self.accelerator, Accelerator) and self.accelerator.get_device() != "cpu": - getattr(torch, self.root_device.type.split(":")[0]).empty_cache() + if isinstance(self.accelerator, Accelerator) and self.accelerator.get_device_type() != "cpu": + getattr(torch, self.root_device.type).empty_cache() else: torch.cuda.empty_cache() @@ -629,7 +629,7 @@ def setup_environment(self) -> None: if ( not isinstance(self.accelerator, CUDAAccelerator) - ) and self.accelerator.get_device() != get_accelerator().device_name(): # type: ignore[union-attr] + ) and self.accelerator.get_device_type() != get_accelerator().device_name(): # type: ignore[union-attr] raise RuntimeError( f"The DeepSpeed strategy is only supported on {get_accelerator().device_name().upper()} GPUs, " f"but `{self.accelerator.__class__.__name__}` is used." diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py index 85da73a3e8bd2..ebb1048670eaa 100644 --- a/src/lightning/fabric/strategies/strategy.py +++ b/src/lightning/fabric/strategies/strategy.py @@ -326,7 +326,7 @@ def load_checkpoint( given, the full checkpoint will be returned. 
""" - if isinstance(self.accelerator, Accelerator) and self.accelerator.get_device() != "cpu": + if isinstance(self.accelerator, Accelerator) and self.accelerator.get_device_type() != "cpu": getattr(torch, self.root_device.type.split(":")[0]).empty_cache() else: torch.cuda.empty_cache() diff --git a/src/lightning/pytorch/accelerators/accelerator.py b/src/lightning/pytorch/accelerators/accelerator.py index a6a3b3a54f176..48e7bcf160834 100644 --- a/src/lightning/pytorch/accelerators/accelerator.py +++ b/src/lightning/pytorch/accelerators/accelerator.py @@ -47,6 +47,6 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: raise NotImplementedError @staticmethod - def get_device() -> str: + def get_device_type() -> str: """Get the device for the current process.""" raise NotImplementedError diff --git a/src/lightning/pytorch/accelerators/cpu.py b/src/lightning/pytorch/accelerators/cpu.py index f6ec6479fb5de..bfe4551311aa4 100644 --- a/src/lightning/pytorch/accelerators/cpu.py +++ b/src/lightning/pytorch/accelerators/cpu.py @@ -82,7 +82,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No @staticmethod @override - def get_device() -> str: + def get_device_type() -> str: return "cpu" diff --git a/src/lightning/pytorch/accelerators/cuda.py b/src/lightning/pytorch/accelerators/cuda.py index 4426fd1a771a4..2ec1d34fa9e04 100644 --- a/src/lightning/pytorch/accelerators/cuda.py +++ b/src/lightning/pytorch/accelerators/cuda.py @@ -115,7 +115,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No @staticmethod @override - def get_device() -> str: + def get_device_type() -> str: return "cuda" diff --git a/src/lightning/pytorch/accelerators/mps.py b/src/lightning/pytorch/accelerators/mps.py index 5676cc79b52e1..a82eb1df6e439 100644 --- a/src/lightning/pytorch/accelerators/mps.py +++ b/src/lightning/pytorch/accelerators/mps.py @@ -89,7 +89,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No @staticmethod @override - def get_device() -> str: + def get_device_type() -> str: return "mps" diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index d6781eebf3341..a2f1ffdcb3392 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -50,7 +50,7 @@ class FSDPPrecision(Precision): """ def __init__( - self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None, device: Optional[str] = None + self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None, device_type: Optional[str] = None ) -> None: supported_precision = get_args(_PRECISION_INPUT) if precision not in supported_precision: @@ -58,7 +58,7 @@ def __init__( f"`precision={precision!r})` is not supported in FSDP." f" `precision` must be one of: {supported_precision}." 
) - self.device = device if device is not None else "cuda" + self.device_type = device_type if device_type is not None else "cuda" from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler @@ -131,7 +131,7 @@ def module_init_context(self) -> AbstractContextManager: def forward_context(self) -> AbstractContextManager: if "mixed" in self.precision: return torch.autocast( - self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16) + self.device_type, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16) ) return _DtypeContextManager(self._desired_input_dtype) diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index af1351c43be84..5a87c02c9f5d4 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -325,7 +325,7 @@ def setup_environment(self) -> None: if ( not isinstance(self.accelerator, CUDAAccelerator) - ) and self.accelerator.get_device() != get_accelerator().device_name(): # type: ignore[union-attr] + ) and self.accelerator.get_device_type() != get_accelerator().device_name(): # type: ignore[union-attr] raise RuntimeError( f"The DeepSpeed strategy is only supported on {get_accelerator().device_name().upper()} GPUs, " f"but `{self.accelerator.__class__.__name__}` is used." diff --git a/src/lightning/pytorch/strategies/strategy.py b/src/lightning/pytorch/strategies/strategy.py index 2fbc00fe13038..ad48f00721de4 100644 --- a/src/lightning/pytorch/strategies/strategy.py +++ b/src/lightning/pytorch/strategies/strategy.py @@ -364,7 +364,7 @@ def lightning_module(self) -> Optional["pl.LightningModule"]: return self._lightning_module def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: - if isinstance(self.accelerator, pl.accelerators.Accelerator) and self.accelerator.get_device() != "cpu": + if isinstance(self.accelerator, pl.accelerators.Accelerator) and self.accelerator.get_device_type() != "cpu": getattr(torch, self.root_device.type.split(":")[0]).empty_cache() else: torch.cuda.empty_cache() diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index 3b773d72de118..475fba22ce497 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -510,7 +510,7 @@ def _check_and_init_precision(self) -> Precision: if isinstance(self.strategy, FSDPStrategy): return FSDPPrecision( precision=self._precision_flag, # type: ignore[arg-type] - device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None, + device_type=self._accelerator_flag.get_device_type() if isinstance(self._accelerator_flag, Accelerator) else None, ) if self._precision_flag in ("16-true", "bf16-true"): return HalfPrecision(self._precision_flag) # type: ignore @@ -536,7 +536,7 @@ def _check_and_init_precision(self) -> Precision: ) device = "cpu" if self._accelerator_flag == "cpu" else "cuda" if isinstance(self._accelerator_flag, Accelerator): - device = self._accelerator_flag.get_device() + device = self._accelerator_flag.get_device_type() return MixedPrecision(self._precision_flag, device) # type: ignore[arg-type] raise RuntimeError("No precision set") diff --git a/tests/tests_fabric/accelerators/test_registry.py b/tests/tests_fabric/accelerators/test_registry.py index d4cb6abda7f45..383bcd3a9c0a6 100644 --- 
a/tests/tests_fabric/accelerators/test_registry.py +++ b/tests/tests_fabric/accelerators/test_registry.py @@ -45,7 +45,7 @@ def get_parallel_devices(devices): return ["foo"] * devices @staticmethod - def get_device(): + def get_device_type(): return "foo" @staticmethod diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py index 4bc4bed4a6666..17817dee64abe 100644 --- a/tests/tests_fabric/test_connector.py +++ b/tests/tests_fabric/test_connector.py @@ -180,7 +180,7 @@ def get_parallel_devices(devices): return [torch.device("cpu")] * devices @staticmethod - def get_device() -> str: + def get_device_type() -> str: return "cpu" @staticmethod diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py index dee229a0a469c..7de9470fbe186 100644 --- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py @@ -193,7 +193,7 @@ def get_parallel_devices(devices): return [torch.device("cpu")] * devices @staticmethod - def get_device() -> str: + def get_device_type() -> str: return "cpu" @staticmethod From b89ca6e80ef61db647cd637e9254d1fce7c40e18 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 20 Dec 2024 01:00:23 +0000 Subject: [PATCH 3/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/fabric/connector.py | 4 +++- src/lightning/fabric/plugins/precision/fsdp.py | 5 ++++- src/lightning/pytorch/plugins/precision/fsdp.py | 5 ++++- .../pytorch/trainer/connectors/accelerator_connector.py | 4 +++- 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index 72dbb4854484c..1f679ba7ffe1a 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -466,7 +466,9 @@ def _check_and_init_precision(self) -> Precision: if isinstance(self.strategy, FSDPStrategy): return FSDPPrecision( precision=self._precision_input, # type: ignore[arg-type] - device_type=self._accelerator_flag.get_device_type() if isinstance(self._accelerator_flag, Accelerator) else None, + device_type=self._accelerator_flag.get_device_type() + if isinstance(self._accelerator_flag, Accelerator) + else None, ) mp_precision_supported = ("32-true", "bf16-mixed", "bf16-true", "16-true") if isinstance(self.strategy, ModelParallelStrategy) and self._precision_input not in mp_precision_supported: diff --git a/src/lightning/fabric/plugins/precision/fsdp.py b/src/lightning/fabric/plugins/precision/fsdp.py index 8fe53aaa9bb7f..3f1cbe8fa3e8b 100644 --- a/src/lightning/fabric/plugins/precision/fsdp.py +++ b/src/lightning/fabric/plugins/precision/fsdp.py @@ -50,7 +50,10 @@ class FSDPPrecision(Precision): """ def __init__( - self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None, device_type: Optional[str] = None + self, + precision: _PRECISION_INPUT, + scaler: Optional["ShardedGradScaler"] = None, + device_type: Optional[str] = None, ) -> None: supported_precision = get_args(_PRECISION_INPUT) if precision not in supported_precision: diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index a2f1ffdcb3392..1f5b92c8285ce 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ 
b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -50,7 +50,10 @@ class FSDPPrecision(Precision): """ def __init__( - self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None, device_type: Optional[str] = None + self, + precision: _PRECISION_INPUT, + scaler: Optional["ShardedGradScaler"] = None, + device_type: Optional[str] = None, ) -> None: supported_precision = get_args(_PRECISION_INPUT) if precision not in supported_precision: diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index 475fba22ce497..af643cd7a284c 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -510,7 +510,9 @@ def _check_and_init_precision(self) -> Precision: if isinstance(self.strategy, FSDPStrategy): return FSDPPrecision( precision=self._precision_flag, # type: ignore[arg-type] - device_type=self._accelerator_flag.get_device_type() if isinstance(self._accelerator_flag, Accelerator) else None, + device_type=self._accelerator_flag.get_device_type() + if isinstance(self._accelerator_flag, Accelerator) + else None, ) if self._precision_flag in ("16-true", "bf16-true"): return HalfPrecision(self._precision_flag) # type: ignore From d352b4c19b3996632085ed15ab3287d7a5beb1c5 Mon Sep 17 00:00:00 2001 From: Zhiyuan Li Date: Fri, 20 Dec 2024 02:06:21 +0000 Subject: [PATCH 4/8] update deepspeed --- src/lightning/fabric/accelerators/accelerator.py | 2 +- src/lightning/fabric/strategies/deepspeed.py | 12 ++++++++---- src/lightning/pytorch/strategies/deepspeed.py | 9 +++++++++ 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/lightning/fabric/accelerators/accelerator.py b/src/lightning/fabric/accelerators/accelerator.py index bdde00d212edd..1017fd5c368ba 100644 --- a/src/lightning/fabric/accelerators/accelerator.py +++ b/src/lightning/fabric/accelerators/accelerator.py @@ -49,7 +49,7 @@ def get_parallel_devices(devices: Any) -> Any: @staticmethod @abstractmethod def get_device_type() -> Any: - """Get the device for the current Accelerator.""" + """Get the device_type for the current Accelerator.""" @staticmethod @abstractmethod diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index bf659a4c44a6a..82c636a282b25 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -299,6 +299,12 @@ def __init__( self._deepspeed_engine: Optional[DeepSpeedEngine] = None + if isinstance(self.accelerator, Accelerator): + self.device_type = self.accelerator.get_device_type() + else: + self.device_type = "cuda" + self.torch_lib = getattr(torch, self.device_type) + @property def zero_stage_3(self) -> bool: assert isinstance(self.config, dict) @@ -511,10 +517,8 @@ def load_checkpoint( optimzer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values()) - if isinstance(self.accelerator, Accelerator) and self.accelerator.get_device_type() != "cpu": - getattr(torch, self.root_device.type).empty_cache() - else: - torch.cuda.empty_cache() + if hasattr(torch, self.device_type) and callable(self.torch_lib.empty_cache): + self.torch_lib.empty_cache() _, client_state = engine.load_checkpoint( path, diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index 5a87c02c9f5d4..df9f97b6fd45c 100644 --- 
a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -319,6 +319,12 @@ def __init__( self.hysteresis = hysteresis self.min_loss_scale = min_loss_scale + try: + self.device_type = self.accelerator.get_device_type() + except Exception: + self.device_type = "cuda" + self.torch_lib = getattr(torch, self.device_type) + @override def setup_environment(self) -> None: from deepspeed.runtime.utils import get_accelerator @@ -672,6 +678,9 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING + if hasattr(torch, self.device_type) and callable(self.torch_lib.empty_cache): + self.torch_lib.empty_cache() + _, client_state = self.deepspeed_engine.load_checkpoint( checkpoint_path, load_optimizer_states=is_fitting, From 5cdd9e72b10e46fea6e7ce33976abf7a4ef3ec56 Mon Sep 17 00:00:00 2001 From: Zhiyuan Li Date: Fri, 20 Dec 2024 03:19:25 +0000 Subject: [PATCH 5/8] update ddp --- src/lightning/fabric/strategies/ddp.py | 20 +++++++++++++------- src/lightning/pytorch/strategies/ddp.py | 20 +++++++++++++------- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py index 6fa74f42e18dc..9faa07b8b2f56 100644 --- a/src/lightning/fabric/strategies/ddp.py +++ b/src/lightning/fabric/strategies/ddp.py @@ -124,13 +124,7 @@ def setup_module(self, module: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" device_ids = self._determine_ddp_device_ids() # https://pytorch.org/docs/stable/notes/cuda.html#id5 - ctx = ( - getattr(torch, f"{self.root_device.type.split(':')[0]}").stream( - getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream() - ) - if device_ids is not None - else nullcontext() - ) + ctx = self._create_stream_context(device_ids=device_ids) with ctx: return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs) @@ -234,6 +228,18 @@ def _set_world_ranks(self) -> None: def _determine_ddp_device_ids(self) -> Optional[list[int]]: return None if self.root_device.type == "cpu" else [self.root_device.index] + def _create_stream_context(self, device_ids=None): + """Create a stream context for the current device, if supported.""" + + torch_lib = getattr(torch, self.root_device.type) + # Check if the device type supports streams and has the necessary attributes. 
+ if hasattr(torch_lib, "Stream") and hasattr(torch_lib, "stream") and device_ids is not None: + stream = torch_lib.Stream() + ctx = torch_lib.stream(stream) + else: + ctx = nullcontext() + return ctx + class _DDPBackwardSyncControl(_BackwardSyncControl): @override diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index 9e46549ed5f84..f36ad79a74194 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ b/src/lightning/pytorch/strategies/ddp.py @@ -190,13 +190,7 @@ def _setup_model(self, model: Module) -> DistributedDataParallel: device_ids = self.determine_ddp_device_ids() log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}") # https://pytorch.org/docs/stable/notes/cuda.html#id5 - ctx = ( - getattr(torch, f"{self.root_device.type.split(':')[0]}").stream( - getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream() - ) - if device_ids is not None - else nullcontext() - ) + ctx = self._create_stream_context(device_ids=device_ids) with ctx: return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs) @@ -424,6 +418,18 @@ def teardown(self) -> None: super().teardown() + def _create_stream_context(self, device_ids=None): + """Create a stream context for the current device, if supported.""" + + torch_lib = getattr(torch, self.root_device.type) + # Check if the device type supports streams and has the necessary attributes. + if hasattr(torch_lib, "Stream") and hasattr(torch_lib, "stream") and device_ids is not None: + stream = torch_lib.Stream() + ctx = torch_lib.stream(stream) + else: + ctx = nullcontext() + return ctx + class _DDPForwardRedirection(_ForwardRedirection): @override From 06a3303742ec09b2af8cebfdd8e74e615bb8ac5c Mon Sep 17 00:00:00 2001 From: Zhiyuan Li Date: Fri, 20 Dec 2024 03:20:18 +0000 Subject: [PATCH 6/8] update amp --- src/lightning/fabric/plugins/precision/amp.py | 2 +- src/lightning/pytorch/plugins/precision/amp.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/fabric/plugins/precision/amp.py b/src/lightning/fabric/plugins/precision/amp.py index a7bf1113d1497..d89ff1597e4bc 100644 --- a/src/lightning/fabric/plugins/precision/amp.py +++ b/src/lightning/fabric/plugins/precision/amp.py @@ -56,7 +56,7 @@ def __init__( if _TORCH_GREATER_EQUAL_2_4 else getattr( torch, - "cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu" else device.split(":")[0], + "cuda" if device.split(":")[0] == "cpu" else device.split(":")[0], ).amp.GradScaler() ) if scaler is not None and self.precision == "bf16-mixed": diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py index d610057e2adda..b6a6fd12771ca 100644 --- a/src/lightning/pytorch/plugins/precision/amp.py +++ b/src/lightning/pytorch/plugins/precision/amp.py @@ -56,7 +56,7 @@ def __init__( if _TORCH_GREATER_EQUAL_2_4 else getattr( torch, - "cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu" else device.split(":")[0], + "cuda" if device.split(":")[0] == "cpu" else device.split(":")[0], ).amp.GradScaler() ) if scaler is not None and self.precision == "bf16-mixed": From 4f7107a84031ff48125c6fbe7f50b2e3bb36cc68 Mon Sep 17 00:00:00 2001 From: zhiyuan li Date: Mon, 13 Jan 2025 20:22:47 +0800 Subject: [PATCH 7/8] update ignore --- src/lightning/pytorch/strategies/ddp.py | 4 ++-- src/lightning/pytorch/strategies/deepspeed.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git 
a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index f36ad79a74194..4b8ab944e37f1 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ b/src/lightning/pytorch/strategies/ddp.py @@ -228,7 +228,7 @@ def _register_ddp_hooks(self) -> None: def _enable_model_averaging(self) -> None: log.debug(f"{self.__class__.__name__}: reinitializing optimizers with post localSGD") - if self._model_averaging_period is None: + if self._model_averaging_period is None: # type: ignore[no-untyped-def] raise ValueError( "Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy." ) @@ -418,7 +418,7 @@ def teardown(self) -> None: super().teardown() - def _create_stream_context(self, device_ids=None): + def _create_stream_context(self, device_ids=None): # type: ignore[no-untyped-def] """Create a stream context for the current device, if supported.""" torch_lib = getattr(torch, self.root_device.type) diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index 4a1ae6f1766ae..4744986e96e74 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -320,7 +320,7 @@ def __init__( self.min_loss_scale = min_loss_scale try: - self.device_type = self.accelerator.get_device_type() + self.device_type = self.accelerator.get_device_type() # type: ignore[union-attr] except Exception: self.device_type = "cuda" self.torch_lib = getattr(torch, self.device_type) From ff1beaec7e4c9dd032942e9c6eb4ccd9f2d6117f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 13 Jan 2025 12:23:09 +0000 Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/pytorch/strategies/ddp.py | 4 ++-- src/lightning/pytorch/strategies/deepspeed.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index 4b8ab944e37f1..f64d7b761b36d 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ b/src/lightning/pytorch/strategies/ddp.py @@ -228,7 +228,7 @@ def _register_ddp_hooks(self) -> None: def _enable_model_averaging(self) -> None: log.debug(f"{self.__class__.__name__}: reinitializing optimizers with post localSGD") - if self._model_averaging_period is None: # type: ignore[no-untyped-def] + if self._model_averaging_period is None: # type: ignore[no-untyped-def] raise ValueError( "Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy." 
) @@ -418,7 +418,7 @@ def teardown(self) -> None: super().teardown() - def _create_stream_context(self, device_ids=None): # type: ignore[no-untyped-def] + def _create_stream_context(self, device_ids=None): # type: ignore[no-untyped-def] """Create a stream context for the current device, if supported.""" torch_lib = getattr(torch, self.root_device.type) diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index 4744986e96e74..45ca752b045e3 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -320,7 +320,7 @@ def __init__( self.min_loss_scale = min_loss_scale try: - self.device_type = self.accelerator.get_device_type() # type: ignore[union-attr] + self.device_type = self.accelerator.get_device_type() # type: ignore[union-attr] except Exception: self.device_type = "cuda" self.torch_lib = getattr(torch, self.device_type)
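
Note on the device-agnostic backend lookup used across these patches: the strategies resolve
the per-backend torch module from the accelerator's device-type string instead of hard-coding
torch.cuda (cache clearing in Strategy.load_checkpoint, the DDP stream context, DeepSpeed's
torch_lib attribute). A minimal runnable sketch of that pattern in plain PyTorch follows; the
device_type value here is only a stand-in for what accelerator.get_device_type() would return.

    from contextlib import nullcontext

    import torch

    # Stand-in for accelerator.get_device_type(); a third-party backend would return e.g. "xpu".
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    torch_backend = getattr(torch, device_type, None)

    # Free cached memory only when the backend module exists and exposes empty_cache();
    # a backend without a memory cache (plain CPU, for instance) is simply skipped.
    if torch_backend is not None and callable(getattr(torch_backend, "empty_cache", None)):
        torch_backend.empty_cache()

    # Enter a side stream only when the backend supports streams, otherwise fall back
    # to a null context, as the patched _create_stream_context helpers do.
    if torch_backend is not None and hasattr(torch_backend, "Stream") and hasattr(torch_backend, "stream"):
        ctx = torch_backend.stream(torch_backend.Stream())
    else:
        ctx = nullcontext()

    with ctx:
        pass  # e.g. wrap the module in DistributedDataParallel here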
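
The precision-plugin changes hinge on torch.autocast accepting a device-type string, so the
forward context can target whatever device type the accelerator reports instead of a
hard-coded "cuda". A small runnable sketch, with device_type and precision as illustrative
stand-ins for the values the connector derives from the accelerator and the precision flag:

    import torch

    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    precision = "16-mixed" if device_type == "cuda" else "bf16-mixed"
    dtype = torch.bfloat16 if precision == "bf16-mixed" else torch.float16

    model = torch.nn.Linear(8, 8).to(device_type)
    batch = torch.randn(4, 8, device=device_type)

    # torch.autocast dispatches on the device-type string, so a third-party backend that
    # registers an autocast implementation needs no special-casing in the plugin.
    with torch.autocast(device_type, dtype=dtype):
        output = model(batch)
    print(output.dtype)  # float16 under CUDA autocast, bfloat16 under CPU autocast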
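
The AMP plugin change (introduced in patch 1 and simplified in patch 6) selects the gradient
scaler in a version-dependent way: torch >= 2.4 ships a generic torch.amp.GradScaler that
takes the device string, while older releases only expose per-backend scalers such as
torch.cuda.amp.GradScaler. A sketch of that selection, using a plain version check in place
of Lightning's _TORCH_GREATER_EQUAL_2_4 flag:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    is_torch_2_4_plus = tuple(int(p) for p in torch.__version__.split("+")[0].split(".")[:2]) >= (2, 4)

    if is_torch_2_4_plus:
        # The generic scaler dispatches on the device string itself.
        scaler = torch.amp.GradScaler(device=device)
    else:
        # Look up the per-backend scaler class; fall back to the CUDA scaler when the
        # device prefix is "cpu", mirroring the fallback in the patched plugin.
        backend = "cuda" if device.split(":")[0] == "cpu" else device.split(":")[0]
        scaler = getattr(torch, backend).amp.GradScaler()

    print(type(scaler))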