torchtitan/components/validate.py (4 additions & 1 deletion)

@@ -147,7 +147,10 @@ def validate(
             # accumulate losses across pipeline microbatches
             # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
             loss = (
-                torch.mean(torch.stack(losses)).to(device_type)
+                # using sum instead of mean because we already rescale the
+                # loss_fn down by a factor of n_microbatches in
+                # torchtitan/distributed/pipeline_parallel.py
+                torch.sum(torch.stack(losses)).to(device_type)
                 if self.pp_has_last_stage
                 else torch.tensor([-1.0], device=device_type)
             )
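For intuition on the mean-to-sum swap: once the loss_fn handed to the pipeline schedule is rescaled by 1/n_microbatches (see the pipeline_parallel.py change below), summing the per-microbatch losses reproduces the same value that torch.mean gave before. A minimal standalone sketch, not torchtitan code; raw_loss_fn and the tensor shapes are illustrative:

    import torch
    import torch.nn.functional as F

    def raw_loss_fn(pred, target):
        # stand-in for the per-microbatch loss used by the schedule
        return F.cross_entropy(pred, target)

    n_microbatches = 4
    preds = [torch.randn(8, 10) for _ in range(n_microbatches)]
    targets = [torch.randint(0, 10, (8,)) for _ in range(n_microbatches)]

    # old behavior: mean of the unscaled per-microbatch losses
    old = torch.mean(
        torch.stack([raw_loss_fn(p, t) for p, t in zip(preds, targets)])
    )

    # new behavior: each loss already carries the 1/n_microbatches factor,
    # so a plain sum reproduces the same value
    new = torch.sum(
        torch.stack(
            [raw_loss_fn(p, t) / n_microbatches for p, t in zip(preds, targets)]
        )
    )

    assert torch.allclose(old, new)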
torchtitan/distributed/pipeline_parallel.py (3 additions & 1 deletion)

@@ -22,6 +22,7 @@
     ScheduleZBVZeroBubble,
 )

+from torchtitan.components.loss import rescale_accumulated_loss
 from torchtitan.config import JobConfig
 from torchtitan.tools.logging import logger

@@ -82,7 +83,8 @@ def build_pipeline_schedule(
     schedule = schedule_class(
         stages if looped_schedule else stages[0],
         n_microbatches=n_microbatches,
-        loss_fn=loss_fn,
+        loss_fn=rescale_accumulated_loss(loss_fn, n_microbatches),
+        scale_grads=False,
     )
     logger.info(
         f"Using pipeline schedule {job_config.parallelism.pipeline_parallel_schedule} "
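The newly imported rescale_accumulated_loss from torchtitan/components/loss.py is not shown in this diff. Judging from its name, the call site, and the comments above, it presumably wraps loss_fn so each per-microbatch loss is divided by n_microbatches; a hedged sketch of that assumed behavior:

    from typing import Callable

    import torch

    def rescale_accumulated_loss(
        loss_fn: Callable[..., torch.Tensor], accumulation_steps: int
    ) -> Callable[..., torch.Tensor]:
        # assumed behavior: scale each per-microbatch loss down so that
        # summing over all microbatches yields the average loss
        def rescaled_loss_fn(*args, **kwargs) -> torch.Tensor:
            return loss_fn(*args, **kwargs) / accumulation_steps

        return rescaled_loss_fn

With the scaling folded into the loss itself, scale_grads=False keeps the schedule from dividing the accumulated gradients by n_microbatches a second time, so the 1/n_microbatches factor is applied exactly once.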
torchtitan/experiments/deepseek_v3/train_ds_dev.py (4 additions & 1 deletion)

@@ -126,7 +126,10 @@ def run_full_model(
             y = pp_schedule.step(x)
         elif pp_rank == pp_size - 1:
             y = pp_schedule.step(target=label, losses=losses)
-            loss = torch.mean(torch.stack(losses))
+            # using sum instead of mean because we already rescale the
+            # loss_fn down by a factor of n_microbatches in
+            # torchtitan/distributed/pipeline_parallel.py
+            loss = torch.sum(torch.stack(losses))
         else:
             pp_schedule.step()
     else:
torchtitan/experiments/forge/example_train.py (4 additions & 1 deletion)

@@ -197,7 +197,10 @@ def forward_backward_step(
             # accumulate losses across pipeline microbatches
             # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
             loss = (
-                torch.mean(torch.stack(losses)).to(self.device)
+                # using sum instead of mean because we already rescale the
+                # loss_fn down by a factor of n_microbatches in
+                # torchtitan/distributed/pipeline_parallel.py
+                torch.sum(torch.stack(losses)).to(self.device)
                 if self.pp_has_last_stage
                 else torch.tensor([-1.0], device=self.device)
             )
torchtitan/train.py (4 additions & 1 deletion)

@@ -452,7 +452,10 @@ def forward_backward_step(
             # accumulate losses across pipeline microbatches
             # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
             loss = (
-                torch.mean(torch.stack(losses)).to(self.device)
+                # using sum instead of mean because we already rescale the
+                # loss_fn down by a factor of n_microbatches in
+                # torchtitan/distributed/pipeline_parallel.py
+                torch.sum(torch.stack(losses)).to(self.device)
                 if self.pp_has_last_stage
                 else torch.tensor([-1.0], device=self.device)
             )