
Support AMP with TPUs #17927

@carmocca


Description & Motivation

Lightning currently supports accelerator="tpu", precision="bf16-mixed", but so far this just sets the XLA_USE_BF16 environment variable.

Side note: why does Fabric also move the data to bf16?
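For reference, a rough sketch of what the current behavior amounts to (illustrative only, not the actual plugin code; the helper name is made up):

```python
import os

import torch

# Setting this flag makes torch_xla materialize float32 tensors as bfloat16 on TPU.
os.environ["XLA_USE_BF16"] = "1"

# The side note above: Fabric additionally casts the input batch explicitly,
# roughly equivalent to this (hypothetical) helper.
def convert_input(data: torch.Tensor) -> torch.Tensor:
    return data.to(torch.bfloat16)
```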

The XLA team added support for automatic mixed precision (AMP). XLA:GPU uses a GradScaler and the autocast context manager, whereas XLA:TPU just uses the latter: https://github.com/pytorch/xla/blob/c9f2d91a234cdaf91f0bbdb044ec94e297ac839a/test/test_train_mp_mnist_amp.py#L143-L147
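For illustration, a minimal TPU training-step sketch following the pattern in that test (autocast only; on XLA:GPU the backward/step would additionally go through a GradScaler). The model, optimizer, and loss here are placeholders:

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast

device = xm.xla_device()
model = torch.nn.Linear(10, 2).to(device)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()


def train_step(data: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    optimizer.zero_grad()
    # On XLA:TPU only the autocast context is needed; no loss scaling.
    with autocast(device):
        output = model(data)
        loss = loss_fn(output, target)
    loss.backward()
    xm.optimizer_step(optimizer)
    return loss
```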

Pitch

Integrate from torch_xla.amp import autocast, GradScaler

The code would be very similar to the non-XLA AMP plugin: https://github.com/Lightning-AI/lightning/blob/master/src/lightning/fabric/plugins/precision/amp.py

This would likely replace our existing XLABf16Precision plugin with an XLAMixedPrecision plugin.
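A rough sketch of what such a plugin could look like, modeled on the non-XLA AMP plugin linked above. The class name matches the pitch, but the method set, the omitted base class, and the GradScaler handling are assumptions, not a final design:

```python
from contextlib import contextmanager
from typing import Generator, Optional

import torch
import torch_xla.core.xla_model as xm
from torch_xla.amp import GradScaler, autocast


class XLAMixedPrecision:  # would subclass the Fabric/PL precision base class
    def __init__(self, precision: str = "bf16-mixed") -> None:
        self.precision = precision
        # TPUs don't need loss scaling; a scaler would only matter for XLA:GPU
        # with float16, mirroring the non-XLA AMP plugin.
        self.scaler: Optional[GradScaler] = GradScaler() if precision == "16-mixed" else None

    @contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        # Same role as MixedPrecision.forward_context, but using torch_xla's
        # autocast, which takes the XLA device.
        with autocast(xm.xla_device()):
            yield

    def backward(self, tensor: torch.Tensor, *args, **kwargs) -> None:
        if self.scaler is not None:
            tensor = self.scaler.scale(tensor)
        tensor.backward(*args, **kwargs)

    def optimizer_step(self, optimizer: torch.optim.Optimizer) -> None:
        if self.scaler is not None:
            self.scaler.step(optimizer)
            self.scaler.update()
        else:
            optimizer.step()
```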

Alternatives

This was just merged upstream. It's likely very experimental. I expect it will be released with PyTorch 2.1.

Additional context

PR on PyTorch: pytorch/pytorch#96370
PR on XLA: pytorch/xla#5161

cc @Borda @carmocca @justusschock @awaelchli @JackCaoG @steventk-g @Liyang90
