From 5c04db3a3c2374dd3b1da1113026dbee659e377b Mon Sep 17 00:00:00 2001
From: Nathan Azrak <42650258+nathan-az@users.noreply.github.com>
Date: Tue, 3 Jun 2025 10:38:08 +1000
Subject: [PATCH 1/5] Make fp8 llama tp plan take model param, remove private prefix

---
 torchtune/models/llama3/_parallelism.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/torchtune/models/llama3/_parallelism.py b/torchtune/models/llama3/_parallelism.py
index 8a2360c39f..3d1873925c 100644
--- a/torchtune/models/llama3/_parallelism.py
+++ b/torchtune/models/llama3/_parallelism.py
@@ -95,13 +95,15 @@ def base_llama_tp_plan(
     return BASE_LLAMA_TP_INFERENCE_PLAN if inference else BASE_LLAMA_TP_TRAINING_PLAN
 
 
-# TODO: expose this once tested
-def _fp8_llama_tp_plan() -> dict[str, ParallelStyle]:
+def fp8_llama_tp_plan(model: nn.Module) -> dict[str, ParallelStyle]:
     """
     Return the tensor parallel plan for Llama3 model that uses float8 for all-gather for both
     rowwise and colwise computation, currently only compatible with float8 fine-tuning with
     "tensorwise" scaling. This tensor parallel plan is shared between 3.1, 3.2, and 3.3 models.
 
+    Args:
+        model (nn.Module): Model to generate plan for (no-op)
+
     Returns:
         dict[str, Any]: The float8-enabled tensor parallel plan for Llama3 model.
     """

From fcd9aff0b2655594b14334b19749f92fa59fa181 Mon Sep 17 00:00:00 2001
From: Nathan Azrak <42650258+nathan-az@users.noreply.github.com>
Date: Tue, 3 Jun 2025 10:38:35 +1000
Subject: [PATCH 2/5] Add fp8 llama3 tp plan to module import paths

---
 torchtune/models/llama3/__init__.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/torchtune/models/llama3/__init__.py b/torchtune/models/llama3/__init__.py
index 5cf4e6b616..4ce89453d6 100644
--- a/torchtune/models/llama3/__init__.py
+++ b/torchtune/models/llama3/__init__.py
@@ -30,4 +30,5 @@
     "qlora_llama3_8b",
     "qlora_llama3_70b",
     "base_llama_tp_plan",
+    "fp8_llama_tp_plan",
 ]
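Patches 1 and 2 above make fp8_llama_tp_plan public: the leading underscore and the "expose this once tested" TODO are dropped, the function gains a model parameter, and the name is added to the package's __all__. The short sketch below shows the resulting call site once patch 4 further down also adds the import to __init__.py. It is illustrative only, not part of the patch series; llama3_8b is assumed from torchtune's existing llama3 builders.

# Illustrative only (not part of the patches): public import path and new
# signature of the plan builder. llama3_8b is an assumed existing builder.
from torchtune.models.llama3 import fp8_llama_tp_plan, llama3_8b

model = llama3_8b()
plan = fp8_llama_tp_plan(model)  # model is accepted but currently a no-op
# plan maps module-name patterns to ParallelStyle objects (dict[str, ParallelStyle])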
From 1d406fe642fad2ca1c6b07885cede2ddc84f1118 Mon Sep 17 00:00:00 2001
From: Nathan Azrak <42650258+nathan-az@users.noreply.github.com>
Date: Tue, 3 Jun 2025 10:39:59 +1000
Subject: [PATCH 3/5] Remove gate on fp8 + TP in full FT recipe

---
 recipes/full_finetune_distributed.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py
index d33d2dad31..b42e6aea49 100644
--- a/recipes/full_finetune_distributed.py
+++ b/recipes/full_finetune_distributed.py
@@ -603,11 +603,6 @@ def _setup_model(
                 raise RuntimeError(
                     "Float8 fine-tuning requires PyTorch 2.8.0.dev20250318 or later."
                 )
-            if self.tp_plan is not None:
-                raise ValueError(
-                    "FP8 training does not support tensor parallelism yet. "
-                    "This will be enabled in the near future."
-                )
             if self.cp_degree > 1:
                 raise ValueError(
                     "Context Parallel for fp8 training is not currently supported"

From 54c440e3651f7fb13e2a2644b83afdfcfa2283ba Mon Sep 17 00:00:00 2001
From: Nathan Azrak <42650258+nathan-az@users.noreply.github.com>
Date: Tue, 3 Jun 2025 13:29:52 +1000
Subject: [PATCH 4/5] Include import path in init

---
 torchtune/models/llama3/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchtune/models/llama3/__init__.py b/torchtune/models/llama3/__init__.py
index 4ce89453d6..4091765ab7 100644
--- a/torchtune/models/llama3/__init__.py
+++ b/torchtune/models/llama3/__init__.py
@@ -15,7 +15,7 @@
     qlora_llama3_70b,
     qlora_llama3_8b,
 )
-from ._parallelism import base_llama_tp_plan
+from ._parallelism import base_llama_tp_plan, fp8_llama_tp_plan
 from ._tokenizer import Llama3Tokenizer
 
 __all__ = [

From 7de07c013a765a260ed97b03af3276e47dbfdb93 Mon Sep 17 00:00:00 2001
From: Nathan Azrak <42650258+nathan-az@users.noreply.github.com>
Date: Tue, 10 Jun 2025 07:23:16 +1000
Subject: [PATCH 5/5] fixed import paths

---
 tests/torchtune/training/test_quantization.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/tests/torchtune/training/test_quantization.py b/tests/torchtune/training/test_quantization.py
index 6581dca99c..35ad54ce61 100644
--- a/tests/torchtune/training/test_quantization.py
+++ b/tests/torchtune/training/test_quantization.py
@@ -10,8 +10,7 @@
 
 from torchao.float8.float8_linear import Float8Linear
 
-from torchtune.models.llama3 import base_llama_tp_plan
-from torchtune.models.llama3._parallelism import _fp8_llama_tp_plan
+from torchtune.models.llama3 import base_llama_tp_plan, fp8_llama_tp_plan
 from torchtune.training.quantization import (
     _validate_float8_tp_plan,
     convert_to_float8_training,
@@ -54,12 +53,12 @@ def _test_validate_float8_tp_plan(self):
         """
         _validate_float8_tp_plan(base_llama_tp_plan())
         _validate_float8_tp_plan(base_llama_tp_plan(), "anything")
-        _validate_float8_tp_plan(_fp8_llama_tp_plan())
-        _validate_float8_tp_plan(_fp8_llama_tp_plan(), "tensorwise")
+        _validate_float8_tp_plan(fp8_llama_tp_plan())
+        _validate_float8_tp_plan(fp8_llama_tp_plan(), "tensorwise")
         with pytest.raises(ValueError):
-            _validate_float8_tp_plan(_fp8_llama_tp_plan(), "rowwise")
+            _validate_float8_tp_plan(fp8_llama_tp_plan(), "rowwise")
         with pytest.raises(ValueError):
-            _validate_float8_tp_plan(_fp8_llama_tp_plan(), "rowwise_with_gw_hp")
+            _validate_float8_tp_plan(fp8_llama_tp_plan(), "rowwise_with_gw_hp")
 
     def test_is_fp8_tensorwise_scaling(self):
         """
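With the plan exported and the fp8 + tensor-parallel gate removed from the full fine-tune recipe in patch 3, the two features can in principle be combined. The sketch below shows one way the pieces might be wired together outside the recipe, following the calls exercised in the test file above. It is a hedged sketch, not the recipe's actual flow: the llama3_8b builder, the WORLD_SIZE-based mesh setup, the convert-then-shard ordering, and the exact convert_to_float8_training signature are assumptions.

# Hedged sketch (assumes a torchrun launch with one GPU per rank); everything
# beyond the calls shown in the patches and tests above is an assumption.
import os

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from torchtune.models.llama3 import fp8_llama_tp_plan, llama3_8b
from torchtune.training.quantization import (
    _validate_float8_tp_plan,
    convert_to_float8_training,
)

model = llama3_8b()

# Swap eligible nn.Linear layers for torchao Float8Linear; assumed to default to
# tensorwise scaling, the only float8 recipe this TP plan supports.
convert_to_float8_training(model)

# Build the float8-aware TP plan (the model argument is currently a no-op) and
# reject float8 recipes the plan cannot handle, as the tests above do.
tp_plan = fp8_llama_tp_plan(model)
_validate_float8_tp_plan(tp_plan, "tensorwise")

# Shard the model over a 1-D tensor-parallel mesh using the float8-enabled plan.
tp_size = int(os.environ["WORLD_SIZE"])  # set by torchrun
tp_mesh = init_device_mesh("cuda", (tp_size,), mesh_dim_names=("tp",))
parallelize_module(model, tp_mesh, tp_plan)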