# Float8 MoE Training

This prototype feature provides a way to use float8 rowwise training on MoE layers.

Below is a simple runnable example of how to use this feature, using the MoE layer
from the [torchtitan](https://github.com/pytorch/torchtitan) Llama4 implementation for demonstration.

```python
import torch
from torch import nn
from torch.nn import functional as F

# this feature requires CUDA and SM89+
assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.quantization.quant_api import quantize_

# this example uses the torchtitan llama4 MoE implementation
try:
    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
    from torchtitan.experiments.llama4.model.moe import MoE
except ImportError as e:
    raise ImportError(
        "torchtitan not installed, see installation instructions at https://github.com/pytorch/torchtitan"
    ) from e

# initialize model
device = torch.device("cuda")
model_args = TransformerModelArgs(
    moe_enabled=True,
    num_experts=8,
    dim=256,
)
model = MoE(model_args).to(torch.bfloat16).to(device)
init_std = 0.02
model.init_weights(init_std, device)

# module filter function to define which modules to quantize
target_fqns = ["experts"]


def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
    for target_fqn in target_fqns:
        if target_fqn in cur_fqn:
            return True
    return False


# quantize the model
config = MoETrainingConfig()
quantize_(model, config=config, filter_fn=moe_module_filter_fn)

# training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
for step in range(10):
    batch, seq, dim = 8, 2048, 256
    x = torch.randn(
        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
    )

    # forward pass
    out = model(x)

    # compute loss
    labels = torch.ones_like(out)
    out_loss = F.mse_loss(out, labels)
    print(f"step {step} loss: {out_loss.item()}")

    # backward pass
    out_loss.backward()

    # update weights and reset gradients for the next step
    optimizer.step()
    optimizer.zero_grad()
```

## Requirements
- torchao nightly build
- CUDA compute capability 8.9+ (SM89+)

## Modeling requirements
This prototype is specifically designed for MoE models that use
`torch._grouped_mm` to implement expert computation in token-choice routing,
where expert weights are implemented as 3D `nn.Parameter`s with `num_experts` as
the leading dim.

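For illustration, a minimal sketch of the kind of experts module this targets is shown below. This is not torchtitan's actual implementation: the class name, the `silu` activation, and the `offsets` convention are assumptions chosen only to show the expected 3D weight layout and the `torch._grouped_mm` call (a private PyTorch op whose exact signature may vary across versions).

```python
import torch
from torch import nn
from torch.nn import functional as F


class GroupedExperts(nn.Module):
    """Illustrative grouped-mm experts module (a sketch, not torchtitan's MoE)."""

    def __init__(self, num_experts: int, dim: int, hidden_dim: int):
        super().__init__()
        # expert weights are 3D nn.Parameters with num_experts as the leading dim
        self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
        self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))

    def forward(self, x: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
        # x: (total_tokens, dim), with tokens routed to the same expert stored
        # contiguously; offsets: per-expert end offsets along the token dim
        h = F.silu(torch._grouped_mm(x, self.w1, offs=offsets))
        return torch._grouped_mm(h, self.w2, offs=offsets)
```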
The `MoETrainingConfig` has a module handler registered to it which will
find all `nn.Parameter`s whose parent module matches the module filter function,
and swap each parameter's data tensor with a `ScaledGroupedMMTensor`.

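As a rough sanity check (illustrative only; parameter names and the exact wrapping depend on the model and the implementation), you can inspect the converted expert parameters after `quantize_`:

```python
# after quantize_(model, ...), expert weights should be backed by the tensor
# subclass, while parameters outside the filtered modules stay plain tensors
for name, param in model.named_parameters():
    if "experts" in name:
        # expected to show ScaledGroupedMMTensor for converted expert weights
        print(name, type(param.data))
```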
The `ScaledGroupedMMTensor` is a tensor subclass which overrides the
`torch._grouped_mm` op by dispatching to a differentiable scaled grouped mm,
which performs dynamic float8 rowwise quantization on the grouped GEMM
operands in both the forward and backward pass.

For all other ops, `ScaledGroupedMMTensor` behaves like a regular `torch.Tensor`.

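To make "dynamic float8 rowwise quantization" concrete, here is a simplified sketch of the scaling scheme: each row of a GEMM operand gets its own scale, derived on the fly from that row's max absolute value. The helper name is hypothetical and this is not the code torchao runs; it only illustrates the rowwise scaling idea.

```python
import torch


def to_float8_rowwise(x: torch.Tensor, float8_dtype=torch.float8_e4m3fn):
    """Simplified illustration of dynamic rowwise float8 scaling (not torchao's kernel)."""
    # one scale per row, derived from that row's max absolute value
    amax = x.abs().amax(dim=-1, keepdim=True).to(torch.float32).clamp(min=1e-12)
    scale = torch.finfo(float8_dtype).max / amax
    x_fp8 = (
        (x.to(torch.float32) * scale)
        .clamp(torch.finfo(float8_dtype).min, torch.finfo(float8_dtype).max)
        .to(float8_dtype)
    )
    # return the quantized tensor and the reciprocal scale used for dequantization
    return x_fp8, scale.reciprocal()


x = torch.randn(16, 256, dtype=torch.bfloat16)
x_fp8, inv_scale = to_float8_rowwise(x)
x_dq = x_fp8.to(torch.bfloat16) * inv_scale.to(torch.bfloat16)
```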
## Limitations
- Only tested with eager mode, single-GPU training so far.
- Composability with parallelisms and `torch.compile` is the next step.