# This script is adapted from https://github.com/pytorch/torchtitan/blob/main/torchtitan/float8.py
import logging
import operator
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Union

import torch
import torch.nn as nn
from lightning_utilities.core.imports import compare_version

log = logging.getLogger(__name__)


def is_sm89_or_later():
    # Float8 matmul kernels require compute capability SM 8.9 or later (Ada/Hopper GPUs, e.g. H100)
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


# check https://github.com/pytorch/ao/blob/main/torchao/float8/config.py for more config details
@dataclass
class FP8Config:
    enable_fp8: bool = True
    enable_amax_init: bool = False
    scaling_type_input: str = "delayed"
    scaling_type_weight: str = "delayed"
    scaling_type_grad_output: str = "delayed"
    enable_fsdp_float8_all_gather: bool = False
    precompute_float8_dynamic_scale_for_fsdp: bool = False
    pad_inner_dim: bool = True
    emulate_fp8: bool = False  # Set to True for testing without FP8 hardware
    enable_torch_compile: bool = True
    enable_pre_and_post_forward: bool = False


# Define a map for module filter functions based on model name
MODULE_FILTER_MAP = {
    "llama": lambda mod, fqn: isinstance(mod, nn.Linear) and "mlp" in fqn and "lm_head" not in fqn,
    "mixtral": lambda mod, fqn: isinstance(mod, nn.Linear)
    and "block_sparse_moe" in fqn
    and "block_sparse_moe.gate" not in fqn
    and "lm_head" not in fqn,
    "default": lambda mod, fqn: isinstance(mod, nn.Linear),  # Default filter
}
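# For example, under the "llama" filter a Linear at an FQN like "model.layers.0.mlp.up_proj" would be converted,
# while "lm_head" and the attention projections keep their original precision. The FQNs shown here are
# illustrative; the exact names depend on the model implementation.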


class Float8TrainingHandler:
    """Handler for configuring models for FP8 training using torchao."""

    def __init__(self, args: FP8Config, model_path: str, parallel_dims: Dict[str, bool]):
        """Initializes the handler for FP8 training and configuration.

        Args:
            args (FP8Config): Configuration object for FP8 training, including settings for scaling, amax
                initialization, and torch compile.
            model_path (str): The path to the model, used to select model-specific settings such as the module filter.
            parallel_dims (Dict[str, bool]): Dictionary specifying parallelization settings, such as whether DP
                sharding (FSDP) is enabled.

        Example Usage:
            fp8_config = FP8Config(
                enable_fp8=True,
                enable_amax_init=True,
                scaling_type_input="delayed",
                scaling_type_weight="delayed",
                scaling_type_grad_output="delayed",
                enable_fsdp_float8_all_gather=False,
                precompute_float8_dynamic_scale_for_fsdp=False,
                pad_inner_dim=True,
                emulate_fp8=False,  # Set to True for testing without FP8 hardware
                enable_torch_compile=True,
                enable_pre_and_post_forward=False,
            )

            parallel_dims = {"dp_shard_enabled": False}
            handler = Float8TrainingHandler(fp8_config, "path/to/model", parallel_dims)

        """
        self.model_path = model_path
        self.args = args
        self.parallel_dims = parallel_dims
        self.compile = args.enable_torch_compile
        self.enable_fp8 = args.enable_fp8

        if not self.enable_fp8:
            log.warning("FP8 training is disabled; the handler will leave the model unchanged.")
            return

        if not is_sm89_or_later() and not args.emulate_fp8:
            log.error("Failed to swap to Float8Linear because float8 is only supported on SM89 or later (H100+ GPUs)")
            raise RuntimeError("Float8Linear operation is not supported on the current hardware.")

        # Check that torchao is installed and its version is >= 0.6.1
        try:
            compare_version("torchao", operator.ge, "0.6.1")
            from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType
        except ImportError as e:
            log.error(str(e))
            raise

        # Configure Float8LinearConfig parameters from args
        scaling_type_input = ScalingType(args.scaling_type_input)
        scaling_type_weight = ScalingType(args.scaling_type_weight)
        scaling_type_grad_output = ScalingType(args.scaling_type_grad_output)

        enable_fsdp_float8_all_gather = (
            parallel_dims.get("dp_shard_enabled", False) and args.enable_fsdp_float8_all_gather
        )

        enable_amax_init = args.enable_amax_init
        self.config = Float8LinearConfig(
            enable_amax_init=enable_amax_init,
            enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
            cast_config_input=CastConfig(scaling_type=scaling_type_input),
            cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
            cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
            enable_pre_and_post_forward=args.enable_pre_and_post_forward,
            pad_inner_dim=args.pad_inner_dim,
            emulate=args.emulate_fp8,
        )

        # For precompute_float8_dynamic_scale_for_fsdp
        self.precompute_scale = enable_fsdp_float8_all_gather and args.precompute_float8_dynamic_scale_for_fsdp

        # For sync_float8_amax_and_scale_history
        self.delayed_scaling = (
            scaling_type_input == ScalingType.DELAYED
            or scaling_type_weight == ScalingType.DELAYED
            or scaling_type_grad_output == ScalingType.DELAYED
        )
        self._sync_float8_amax_and_scale_history = None

        log.info("Float8 training active")
    def convert_to_float8_training(self, model: nn.Module, module_filter_fn: Optional[Callable] = None):
        """Converts the linear layers of `model` to `Float8Linear` based on a module filter function. Mutates the
        model in place.

        Args:
            model (nn.Module): The model whose layers should be converted.
            module_filter_fn (Callable, optional): A function ``(module, fqn) -> bool`` that selects which modules
                should be replaced. Defaults to a model-specific filter based on `model_path`.

        """
        if not self.enable_fp8:
            log.warning("FP8 is disabled, so layers will not be replaced.")
            return

        log.warning("Enabling FP8 Training")

        # Use the provided filter function or select from the map
        if module_filter_fn is None:
            model_path_lower = self.model_path.lower()
            module_filter_fn = next(
                (fn for key, fn in MODULE_FILTER_MAP.items() if key in model_path_lower),
                MODULE_FILTER_MAP["default"],  # Default filter if no match is found
            )

        from torchao.float8 import convert_to_float8_training

        convert_to_float8_training(
            model,
            config=self.config,
            module_filter_fn=module_filter_fn,
        )
        log.info(
            f"Swapped to Float8Linear layers with enable_fsdp_float8_all_gather={self.config.enable_fsdp_float8_all_gather}"
        )

    def precompute_float8_dynamic_scale_for_fsdp(self, model: Union[nn.Module, List[nn.Module]]):
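        """Precomputes dynamic FP8 scales for FSDP all-gathered weights.

        Intended to be called after `optimizer.step()`. No-op unless FP8, FSDP float8 all-gather, and scale
        precomputation are all enabled.
        """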
        if not self.enable_fp8 or not self.precompute_scale:
            return

        from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp

        models = [model] if isinstance(model, nn.Module) else model
        for m in models:
            precompute_float8_dynamic_scale_for_fsdp(m)

    def sync_float8_amax_and_scale_history(self, model: Union[nn.Module, List[nn.Module]]):
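        """Syncs amax history and recomputes FP8 scales for layers that use delayed scaling.

        Intended to be called once per training step, after `backward()` and before `optimizer.step()`. No-op unless
        FP8 and at least one delayed scaling type are enabled.
        """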
        if not self.enable_fp8 or not self.delayed_scaling:
            return

        from torchao.float8 import sync_float8_amax_and_scale_history

        # Cache the compiled function if necessary
        if self._sync_float8_amax_and_scale_history is None:
            if self.compile:
                self._sync_float8_amax_and_scale_history = torch.compile(sync_float8_amax_and_scale_history)
            else:
                self._sync_float8_amax_and_scale_history = sync_float8_amax_and_scale_history

        models = [model] if isinstance(model, nn.Module) else model
        for m in models:
            self._sync_float8_amax_and_scale_history(m)
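

# The block below is a minimal, illustrative sketch of how the handler could be wired into a training loop.
# The toy model, optimizer, loss, and hyperparameters are placeholders (not part of the original script), and it
# assumes a CUDA device: either an SM 8.9+ GPU, or an older GPU relying on `emulate_fp8=True`.
if __name__ == "__main__":
    import torch.nn.functional as F

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            # "mlp" in the FQN makes these Linears eligible under the "llama" filter; "lm_head" is excluded.
            self.mlp = nn.Sequential(nn.Linear(64, 256), nn.ReLU(), nn.Linear(256, 64))
            self.lm_head = nn.Linear(64, 16)

        def forward(self, x):
            return self.lm_head(self.mlp(x))

    fp8_config = FP8Config(emulate_fp8=not is_sm89_or_later(), enable_torch_compile=False)
    handler = Float8TrainingHandler(fp8_config, model_path="llama-toy", parallel_dims={"dp_shard_enabled": False})

    model = ToyModel().to("cuda")
    handler.convert_to_float8_training(model)  # swaps the "mlp" Linears for Float8Linear in place

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    for _ in range(3):
        x = torch.randn(16, 64, device="cuda")
        target = torch.randint(0, 16, (16,), device="cuda")
        loss = F.cross_entropy(model(x), target)
        loss.backward()
        # With delayed scaling, amax/scale histories must be synced every step before the optimizer update.
        handler.sync_float8_amax_and_scale_history(model)
        optimizer.step()
        optimizer.zero_grad()
        # Only does work when FSDP float8 all-gather + scale precomputation are enabled; a no-op here.
        handler.precompute_float8_dynamic_scale_for_fsdp(model)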