
Ungate FP8 + TP #2781

Closed
wants to merge 5 commits
5 changes: 0 additions & 5 deletions recipes/full_finetune_distributed.py
@@ -603,11 +603,6 @@ def _setup_model(
                 raise RuntimeError(
                     "Float8 fine-tuning requires PyTorch 2.8.0.dev20250318 or later."
                 )
-            if self.tp_plan is not None:
-                raise ValueError(
-                    "FP8 training does not support tensor parallelism yet. "
-                    "This will be enabled in the near future."
-                )
             if self.cp_degree > 1:
                 raise ValueError(
                     "Context Parallel for fp8 training is not currently supported"
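With this gate removed, enabling float8 training together with a tensor parallel plan is no longer rejected by the recipe. A minimal sketch of the now-permitted combination outside the recipe, assuming a Llama3 8B builder, an 8-way TP mesh launched via torchrun, and default arguments for convert_to_float8_training (the mesh setup and defaults are illustrative assumptions, not the recipe's exact code):

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from torchtune.models.llama3 import fp8_llama_tp_plan, llama3_8b
from torchtune.training.quantization import convert_to_float8_training

# Sketch only: run under torchrun so the process group and mesh can be created.
tp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))

model = llama3_8b()

# Swap nn.Linear modules for torchao's Float8Linear (tensorwise scaling assumed).
convert_to_float8_training(model)

# Apply the float8-aware TP plan so tensor-parallel all-gathers run in float8.
parallelize_module(model, tp_mesh, fp8_llama_tp_plan(model))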
11 changes: 5 additions & 6 deletions tests/torchtune/training/test_quantization.py
@@ -10,8 +10,7 @@

 from torchao.float8.float8_linear import Float8Linear

-from torchtune.models.llama3 import base_llama_tp_plan
-from torchtune.models.llama3._parallelism import _fp8_llama_tp_plan
+from torchtune.models.llama3 import base_llama_tp_plan, fp8_llama_tp_plan
 from torchtune.training.quantization import (
     _validate_float8_tp_plan,
     convert_to_float8_training,
@@ -54,12 +53,12 @@ def _test_validate_float8_tp_plan(self):
         """
         _validate_float8_tp_plan(base_llama_tp_plan())
         _validate_float8_tp_plan(base_llama_tp_plan(), "anything")
-        _validate_float8_tp_plan(_fp8_llama_tp_plan())
-        _validate_float8_tp_plan(_fp8_llama_tp_plan(), "tensorwise")
+        _validate_float8_tp_plan(fp8_llama_tp_plan())
+        _validate_float8_tp_plan(fp8_llama_tp_plan(), "tensorwise")
         with pytest.raises(ValueError):
-            _validate_float8_tp_plan(_fp8_llama_tp_plan(), "rowwise")
+            _validate_float8_tp_plan(fp8_llama_tp_plan(), "rowwise")
         with pytest.raises(ValueError):
-            _validate_float8_tp_plan(_fp8_llama_tp_plan(), "rowwise_with_gw_hp")
+            _validate_float8_tp_plan(fp8_llama_tp_plan(), "rowwise_with_gw_hp")

     def test_is_fp8_tensorwise_scaling(self):
         """
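The updated test also documents the compatibility contract: the float8 all-gather plan pairs only with tensorwise scaling, and rowwise fp8 recipes are rejected. A short caller-side sketch of that contract (the model argument to fp8_llama_tp_plan is documented as a no-op, so a placeholder module is used here purely for illustration):

import torch.nn as nn

from torchtune.models.llama3 import fp8_llama_tp_plan
from torchtune.training.quantization import _validate_float8_tp_plan

plan = fp8_llama_tp_plan(nn.Identity())  # placeholder module; the argument is a no-op

_validate_float8_tp_plan(plan)                # accepted: no fp8 recipe name given
_validate_float8_tp_plan(plan, "tensorwise")  # accepted: tensorwise scaling
try:
    _validate_float8_tp_plan(plan, "rowwise")  # rowwise recipes are rejected
except ValueError as err:
    print(err)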
3 changes: 2 additions & 1 deletion torchtune/models/llama3/__init__.py
@@ -15,7 +15,7 @@
     qlora_llama3_70b,
     qlora_llama3_8b,
 )
-from ._parallelism import base_llama_tp_plan
+from ._parallelism import base_llama_tp_plan, fp8_llama_tp_plan
 from ._tokenizer import Llama3Tokenizer

 __all__ = [
@@ -30,4 +30,5 @@
     "qlora_llama3_8b",
     "qlora_llama3_70b",
     "base_llama_tp_plan",
+    "fp8_llama_tp_plan",
 ]
6 changes: 4 additions & 2 deletions torchtune/models/llama3/_parallelism.py
@@ -95,13 +95,15 @@ def base_llama_tp_plan(
     return BASE_LLAMA_TP_INFERENCE_PLAN if inference else BASE_LLAMA_TP_TRAINING_PLAN


-# TODO: expose this once tested
-def _fp8_llama_tp_plan() -> dict[str, ParallelStyle]:
+def fp8_llama_tp_plan(model: nn.Module) -> dict[str, ParallelStyle]:
     """
     Return the tensor parallel plan for Llama3 model that uses float8 for all-gather for both
     rowwise and colwise computation, currently only compatible with float8 fine-tuning with
     "tensorwise" scaling. This tensor parallel plan is shared between 3.1, 3.2, and 3.3 models.

+    Args:
+        model (nn.Module): Model to generate plan for (no-op)
+
     Returns:
         dict[str, Any]: The float8-enabled tensor parallel plan for Llama3 model.
     """