Description & Motivation
Lightning currently supports accelerator="tpu", precision="bf16-mixed", but so far this just sets the XLA_USE_BF16 environment variable:
- Trainer: https://github.com/Lightning-AI/lightning/blob/master/src/lightning/pytorch/plugins/precision/xlabf16.py
- Fabric: https://github.com/Lightning-AI/lightning/blob/master/src/lightning/fabric/plugins/precision/xlabf16.py
Side note: why does Fabric also move the data to bf16?
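For reference, a simplified sketch of what the current bf16 plugins effectively do; the class and method names here are illustrative, not the actual plugin interface:

```python
import os


class XLABf16PrecisionSketch:
    """Illustrative only: enable XLA's automatic bf16 casting via an env var."""

    def setup(self) -> None:
        # XLA reads this variable and runs float32 ops in bfloat16
        os.environ["XLA_USE_BF16"] = "1"

    def teardown(self) -> None:
        os.environ.pop("XLA_USE_BF16", None)
```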
The XLA team added support for automatic mixed precision (AMP). XLA:GPU uses a GradScaler together with the autocast context manager, whereas XLA:TPU just uses the latter: https://github.com/pytorch/xla/blob/c9f2d91a234cdaf91f0bbdb044ec94e297ac839a/test/test_train_mp_mnist_amp.py#L143-L147
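A minimal sketch of the upstream usage pattern, adapted from the linked test_train_mp_mnist_amp.py; model, data, loss_fn, and optimizer are placeholders, and the GPU/TPU branching mirrors the test (the exact torch_xla helpers may differ across versions):

```python
import torch_xla.core.xla_model as xm
from torch_xla.amp import GradScaler, autocast

device = xm.xla_device()
# Only XLA:GPU needs gradient scaling; XLA:TPU relies on autocast alone
scaler = GradScaler() if xm.xla_device_hw(device) == "GPU" else None


def train_step(model, data, target, loss_fn, optimizer):
    optimizer.zero_grad()
    with autocast(device):  # torch_xla's autocast takes the XLA device explicitly
        output = model(data)
        loss = loss_fn(output, target)
    if scaler is not None:  # XLA:GPU: scale the loss, step, then update the scaler
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:  # XLA:TPU: plain backward, then step via xm.optimizer_step
        loss.backward()
        xm.optimizer_step(optimizer)
    return loss
```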
Pitch
Integrate from torch_xla.amp import autocast, GradScaler
The code would be very similar to the non-XLA AMP plugin: https://github.com/Lightning-AI/lightning/blob/master/src/lightning/fabric/plugins/precision/amp.py
This would likely replace our existing XLABf16Precision plugin with an XLAMixedPrecision plugin, as sketched below.
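A rough sketch of what such a plugin could look like, loosely modeled on the non-XLA AMP plugin linked above; the class name, method names, and the "16-mixed"/"bf16-mixed" handling are illustrative assumptions, not Lightning's actual precision-plugin interface:

```python
from contextlib import contextmanager
from typing import Generator, Optional

import torch
import torch_xla.core.xla_model as xm
from torch_xla.amp import GradScaler, autocast


class XLAMixedPrecision:
    def __init__(self, precision: str = "bf16-mixed") -> None:
        self.precision = precision
        # GradScaler is only needed for float16 (XLA:GPU); TPU bf16 uses autocast alone
        self.scaler: Optional[GradScaler] = GradScaler() if precision == "16-mixed" else None

    @contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        # torch_xla's autocast takes the XLA device explicitly and picks the
        # appropriate lower-precision dtype for the hardware
        with autocast(xm.xla_device()):
            yield

    def backward(self, loss: torch.Tensor) -> None:
        if self.scaler is not None:
            loss = self.scaler.scale(loss)
        loss.backward()

    def optimizer_step(self, optimizer: torch.optim.Optimizer) -> None:
        if self.scaler is not None:
            self.scaler.step(optimizer)
            self.scaler.update()
        else:
            optimizer.step()
```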
Alternatives
This support was only just merged upstream, so it is likely still very experimental. I expect it will be released with PyTorch 2.1.
Additional context
PR on PyTorch: pytorch/pytorch#96370
PR on XLA: pytorch/xla#5161
cc @Borda @carmocca @justusschock @awaelchli @JackCaoG @steventk-g @Liyang90