
Remove device_count for TPU launcher to avoid initializing runtime #3587

Open · wants to merge 1 commit into base: main
5 changes: 0 additions & 5 deletions src/accelerate/commands/launch.py
@@ -877,7 +877,6 @@ def deepspeed_launcher(args):

def tpu_launcher(args):
    import torch_xla.distributed.xla_multiprocessing as xmp
-    from torch_xla import device_count

    if args.no_python:
        raise ValueError("--no_python cannot be used with TPU launcher")
@@ -898,10 +897,6 @@ def tpu_launcher(args):
            f"Your training script should have a function named {args.main_training_function}, or you should pass a "
            "different value to `--main_training_function`."
        )
Author

Suggested change:
-        )
+        )
+    logger.warning("Launching training on all TPU cores.")

Like that? @SunMarc

Member

In the notebook_launcher function, the print("Launching a training on TPU cores.") call also needs to be updated to print("Launching a training on all TPU cores.").

Member

Here I think we can just tell users that if they pass args.num_processes, we log a warning saying that they can't choose the number of devices and that all devices are used by default. WDYT?
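
A minimal sketch of that suggestion (hypothetical, not part of this diff; it assumes the `args` and `logger` objects already available in `tpu_launcher`):

```python
# Hypothetical sketch of the reviewer's suggestion: instead of raising when
# --num_processes is set, warn that the value is ignored and that all
# available TPU cores are used.
if args.num_processes:
    logger.warning(
        "`--num_processes` is ignored by the TPU launcher: the number of devices "
        "cannot be chosen, and all available TPU cores will be used."
    )
```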

-    if args.num_processes and args.num_processes != device_count():
-        raise ValueError(
-            f"Number of processes ({args.num_processes}) must match the number of TPU devices ({device_count()})"
-        )
Comment on lines -901 to -904
@SunMarc (Member) commented on May 26, 2025:

If you are removing that, maybe we should note somewhere that the script is being run on all available TPU cores, e.g. print("Launching a training on all TPU cores.").

Contributor

IIRC, XLA now expects no arguments at all, meaning it will use all TPU cores. If that is correct, this change looks right.


    # Patch sys.argv
    sys.argv = [mod.__file__] + args.training_script_args
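
As the contributor's comment above notes, recent torch_xla releases expect xmp.spawn to be launched without an explicit device count, so querying device_count() in the launcher process (which initializes the XLA runtime before the workers are forked) is no longer needed. Below is a minimal, hypothetical sketch of that launch pattern, assuming a recent torch_xla release; main and _mp_fn are placeholder names, not part of this PR:

```python
import torch_xla.distributed.xla_multiprocessing as xmp


def main():
    # Placeholder for the user's training entry point
    # (what accelerate resolves via --main_training_function).
    pass


def _mp_fn(index):
    # `index` is the ordinal of the spawned worker process.
    main()


if __name__ == "__main__":
    # Calling torch_xla.device_count() here would initialize the XLA runtime in
    # the parent process; instead, xmp.spawn discovers the devices itself and
    # starts one worker per available TPU core.
    xmp.spawn(_mp_fn, args=())
```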