
Commit 887199a

refactor and add new examples
1 parent 8ce5287 commit 887199a

File tree

11 files changed: +837 -0 lines changed

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
# PyTorch Native FP8 Training with FSDP1/2 and Torch Compile using Custom Handler

This is an example of PyTorch-native FP8 training with FSDP1/2 and `torch.compile`, driven by custom handlers.

## Requirements

Install the requirements by running

```bash
sh setup.sh
```

## Example

In this example we present custom handlers for FP8 training (`Float8TrainingHandler`) and FSDP2 sharding (`FSDP2Handler`) that can be combined with `torch.compile`, as sketched below:

```bash

```

## Test the handlers

```bash
# configure the PYTHONPATH if needed
# export PYTHONPATH=/teamspace/studios/this_studio/pytorch-lightning/examples/pytorch/custom_handler_fp8_fsdp1n2_compile:$PYTHONPATH
cd pytorch-lightning/examples/pytorch/custom_handler_fp8_fsdp1n2_compile
pytest tests/*
```

> **Warning**
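For orientation, here is a hedged sketch of the composition referenced in the Example section above, using the two handlers defined later in this commit; `model`, `model_path`, and `device_mesh` are placeholders, and the actual training script may wire things up differently:

```python
# Illustrative only: the handlers added by this commit, composed in the usual order.
fp8_handler = Float8TrainingHandler(FP8Config(enable_fp8=True), model_path, {"dp_shard_enabled": True})
fp8_handler.convert_to_float8_training(model)   # swap eligible nn.Linear layers to Float8Linear

fsdp_handler = FSDP2Handler(FSDP2Config(), device_mesh)
model = fsdp_handler.wrap_model(model)          # shard transformer layers with FSDP2

model = torch.compile(model)                    # torch.compile can be layered on top (or applied per block)
```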

examples/pytorch/custom_handler_fp8_fsdp1n2_compile/handlers/__init__.py

Whitespace-only changes.
Lines changed: 192 additions & 0 deletions
@@ -0,0 +1,192 @@
# the script is modified based on https://github.com/pytorch/torchtitan/blob/main/torchtitan/float8.py
import logging
import operator
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Union

import torch
import torch.nn as nn
from lightning_utilities.core.imports import compare_version

log = logging.getLogger(__name__)


def is_sm89_or_later():
    # Float8 is only supported on SM89 or later (H100+ GPUs)
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


# check https://github.com/pytorch/ao/blob/main/torchao/float8/config.py for more config details
@dataclass
class FP8Config:
    enable_fp8: bool = True
    enable_amax_init: bool = False
    scaling_type_input: str = "delayed"
    scaling_type_weight: str = "delayed"
    scaling_type_grad_output: str = "delayed"
    enable_fsdp_float8_all_gather: bool = False
    precompute_float8_dynamic_scale_for_fsdp: bool = False
    pad_inner_dim: bool = True
    emulate_fp8: bool = False  # Set to True for testing without FP8 hardware
    enable_torch_compile: bool = True
    enable_pre_and_post_forward: bool = False


# Define a map for module filter functions based on model name
MODULE_FILTER_MAP = {
    "llama": lambda mod, fqn: isinstance(mod, nn.Linear) and "mlp" in fqn and "lm_head" not in fqn,
    "mixtral": lambda mod, fqn: isinstance(mod, nn.Linear)
    and "block_sparse_moe" in fqn
    and "block_sparse_moe.gate" not in fqn
    and "lm_head" not in fqn,
    "default": lambda mod, fqn: isinstance(mod, nn.Linear),  # Default filter
}


class Float8TrainingHandler:
    """Handler for configuring models for FP8 training using torchao."""

    def __init__(self, args: FP8Config, model_path: str, parallel_dims: Dict[str, bool]):
        """Initializes the handler for FP8 training and configuration.

        Args:
            args (FP8Config): Configuration object for FP8 training, including settings for scaling, amax initialization, and torch compile.
            model_path (str): The path to the model. Typically used for determining model-specific settings.
            parallel_dims (Dict[str, bool]): Dictionary specifying parallelization settings, such as whether DP shard is enabled.

        Example Usage:
            fp8_config = FP8Config(
                enable_fp8=True,
                enable_amax_init=True,
                scaling_type_input="delayed",
                scaling_type_weight="delayed",
                scaling_type_grad_output="delayed",
                enable_fsdp_float8_all_gather=False,
                precompute_float8_dynamic_scale_for_fsdp=False,
                pad_inner_dim=True,
                emulate_fp8=False,  # Set to True for testing without FP8 hardware
                enable_torch_compile=True,
                enable_pre_and_post_forward=False,
            )

            parallel_dims = {"dp_shard_enabled": False}
            handler = Float8TrainingHandler(fp8_config, "path/to/model", parallel_dims)

        """
        self.model_path = model_path
        self.args = args
        self.parallel_dims = parallel_dims
        self.compile = args.enable_torch_compile
        self.enable_fp8 = args.enable_fp8

        if not self.enable_fp8:
            log.warning("FP8 training is disabled.")
            return

        if not is_sm89_or_later() and not args.emulate_fp8:
            log.error("Failed to swap to Float8Linear because float8 is only supported on SM89 or later (H100+ GPUs)")
            raise RuntimeError("Float8Linear operation is not supported on the current hardware.")

        # Check if torchao is installed and version is >= 0.6.1
        try:
            compare_version("torchao", operator.ge, "0.6.1")
            from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType
        except ImportError as e:
            log.error(str(e))
            raise

        # Configure Float8LinearConfig parameters from args
        scaling_type_input = ScalingType(args.scaling_type_input)
        scaling_type_weight = ScalingType(args.scaling_type_weight)
        scaling_type_grad_output = ScalingType(args.scaling_type_grad_output)

        enable_fsdp_float8_all_gather = (
            parallel_dims.get("dp_shard_enabled", False) and args.enable_fsdp_float8_all_gather
        )

        enable_amax_init = args.enable_amax_init
        self.config = Float8LinearConfig(
            enable_amax_init=enable_amax_init,
            enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
            cast_config_input=CastConfig(scaling_type=scaling_type_input),
            cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
            cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
            enable_pre_and_post_forward=args.enable_pre_and_post_forward,
            pad_inner_dim=args.pad_inner_dim,
            emulate=args.emulate_fp8,
        )

        # For precompute_float8_dynamic_scale_for_fsdp
        self.precompute_scale = enable_fsdp_float8_all_gather and args.precompute_float8_dynamic_scale_for_fsdp

        # For sync_float8_amax_and_scale_history
        self.delayed_scaling = (
            scaling_type_input == ScalingType.DELAYED
            or scaling_type_weight == ScalingType.DELAYED
            or scaling_type_grad_output == ScalingType.DELAYED
        )
        self._sync_float8_amax_and_scale_history = None

        log.info("Float8 training active")

    def convert_to_float8_training(self, model: nn.Module, module_filter_fn: Optional[Callable] = None):
        """Converts the linear layers of `model` to `Float8Linear` based on a module filter function. Mutates the model
        in place.

        Args:
            model (nn.Module): The model whose layers should be converted.
            module_filter_fn (Callable, optional): A function to filter which modules should be replaced.
                Defaults to a model-specific filter based on `model_path`.

        """
        if not self.enable_fp8:
            log.warning("FP8 is disabled, so layers will not be replaced.")
            return

        log.warning("Enabling FP8 Training")

        # Use the provided filter function or select from the map
        if module_filter_fn is None:
            model_path_lower = self.model_path.lower()
            module_filter_fn = next(
                (fn for key, fn in MODULE_FILTER_MAP.items() if key in model_path_lower),
                MODULE_FILTER_MAP["default"],  # Default filter if no match is found
            )

        from torchao.float8 import convert_to_float8_training

        convert_to_float8_training(
            model,
            config=self.config,
            module_filter_fn=module_filter_fn,
        )
        log.info(
            f"Swapped to Float8Linear layers with enable_fsdp_float8_all_gather={self.config.enable_fsdp_float8_all_gather}"
        )

    def precompute_float8_dynamic_scale_for_fsdp(self, model: Union[nn.Module, List[nn.Module]]):
        if not self.enable_fp8 or not self.precompute_scale:
            return

        from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp

        models = [model] if isinstance(model, nn.Module) else model
        for m in models:
            precompute_float8_dynamic_scale_for_fsdp(m)

    def sync_float8_amax_and_scale_history(self, model: Union[nn.Module, List[nn.Module]]):
        if not self.enable_fp8 or not self.delayed_scaling:
            return

        from torchao.float8 import sync_float8_amax_and_scale_history

        # Cache the compiled function if necessary
        if self._sync_float8_amax_and_scale_history is None:
            if self.compile:
                self._sync_float8_amax_and_scale_history = torch.compile(sync_float8_amax_and_scale_history)
            else:
                self._sync_float8_amax_and_scale_history = sync_float8_amax_and_scale_history

        models = [model] if isinstance(model, nn.Module) else model
        for m in models:
            self._sync_float8_amax_and_scale_history(m)
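For orientation, a hedged sketch (not part of this diff) of how the handler's per-step hooks are typically driven; the ordering follows the upstream torchtitan loop this file is adapted from: sync the amax/scale history after `backward()` when delayed scaling is active, and precompute dynamic FP8 scales after `optimizer.step()` when FSDP float8 all-gather is enabled. The model, optimizer, dataloader, and model path below are placeholders.

```python
# Hypothetical training-step sketch; `model`, `optimizer`, and `train_dataloader` are assumed
# to exist, and "path/to/llama/model" is a placeholder path (it selects the "llama" module
# filter defined above).
fp8_config = FP8Config(enable_fp8=True, enable_torch_compile=True)
fp8_handler = Float8TrainingHandler(fp8_config, "path/to/llama/model", {"dp_shard_enabled": True})

# Swap eligible nn.Linear layers for Float8Linear before FSDP wrapping / torch.compile.
fp8_handler.convert_to_float8_training(model)

for batch in train_dataloader:
    loss = model(**batch).loss
    loss.backward()

    # Delayed scaling: refresh the amax/scale history before the optimizer consumes the gradients.
    fp8_handler.sync_float8_amax_and_scale_history(model)

    optimizer.step()
    optimizer.zero_grad()

    # Dynamic scaling + FSDP float8 all-gather: precompute scales for the next iteration.
    fp8_handler.precompute_float8_dynamic_scale_for_fsdp(model)
```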
Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
import logging
import operator
from dataclasses import dataclass
from typing import TYPE_CHECKING

import torch
import torch.nn as nn
from lightning_utilities.core.imports import compare_version

if TYPE_CHECKING:
    from torch.distributed.device_mesh import DeviceMesh

log = logging.getLogger(__name__)


@dataclass
class FSDP2Config:
    enable_cpu_offload: bool = False
    enable_gradient_checkpointing: bool = False


class FSDP2Handler:
    """Handler for wrapping the model layers with FSDP2.

    Args:
        args (FSDP2Config): Configuration for FSDP2, including options for CPU offload and gradient checkpointing.
        device_mesh (DeviceMesh): Device mesh configuration for FSDP2 parallelism.

    Attributes:
        args (FSDP2Config): Stores the FSDP2 configuration.
        device_mesh (DeviceMesh): Stores the device mesh configuration.

    """

    def __init__(self, args: FSDP2Config, device_mesh: "DeviceMesh"):
        self.args = args
        self.device_mesh = device_mesh

        # Check PyTorch version for FSDP2 support (currently we require PyTorch >= 2.6.0)
        try:
            compare_version("torch", operator.ge, "2.6.0")
        except RuntimeError as e:
            log.error(str(e))
            raise

        # Import necessary FSDP modules
        try:
            from torch.distributed._composable.fsdp import (
                CPUOffloadPolicy,
                MixedPrecisionPolicy,
                fully_shard,
            )
            from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
                checkpoint_wrapper,
            )

            self.fully_shard = fully_shard
            self.checkpoint_wrapper = checkpoint_wrapper
            self.MixedPrecisionPolicy = MixedPrecisionPolicy
            self.CPUOffloadPolicy = CPUOffloadPolicy
        except ImportError as e:
            log.error(f"Failed to import FSDP modules: {e}")
            raise

    def wrap_model(self, model: nn.Module):
        """Wraps the model layers with FSDP configurations.

        Args:
            model (nn.Module): The model to wrap.

        Returns:
            nn.Module: The wrapped model.

        """
        dp_mesh = self.device_mesh["data_parallel"]
        assert dp_mesh.size() > 1, "FSDP requires at least two devices."

        fsdp_policy = dict(
            mesh=dp_mesh,
            mp_policy=self.MixedPrecisionPolicy(
                param_dtype=torch.bfloat16,
                reduce_dtype=torch.float32,
            ),
        )
        if self.args.enable_cpu_offload:
            fsdp_policy["offload_policy"] = self.CPUOffloadPolicy()

        for layer_id, module in enumerate(model.model.layers):
            reshard_after_forward = layer_id < len(model.model.layers) - 1
            if self.args.enable_gradient_checkpointing:
                module = self.checkpoint_wrapper(module)
            self.fully_shard(
                module,
                **fsdp_policy,
                reshard_after_forward=reshard_after_forward,
            )
            model.model.layers[layer_id] = module

        self.fully_shard(model, **fsdp_policy)
        return model
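Below is a minimal sketch (not part of this commit) of the setup `wrap_model` assumes: a 1-D `DeviceMesh` whose dimension is named `"data_parallel"`, and a model that exposes `model.model.layers` (e.g. a Hugging Face causal LM). The launch command, backend, and world size are assumptions.

```python
# Hypothetical launch-time setup; assumes a torchrun launch (e.g. `torchrun --nproc_per_node=8 ...`)
# and a model built beforehand that exposes `model.model.layers`.
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group(backend="nccl")
world_size = dist.get_world_size()

# The handler looks up self.device_mesh["data_parallel"], so the mesh must name that dimension.
device_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("data_parallel",))

fsdp_config = FSDP2Config(enable_cpu_offload=False, enable_gradient_checkpointing=True)
fsdp_handler = FSDP2Handler(fsdp_config, device_mesh)
model = fsdp_handler.wrap_model(model)  # `model` assumed to be constructed beforehand
```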
