Remove device_count for TPU launcher to avoid initializing runtime #3587
base: main
@@ -877,7 +877,6 @@ def deepspeed_launcher(args):
 
 def tpu_launcher(args):
     import torch_xla.distributed.xla_multiprocessing as xmp
-    from torch_xla import device_count
 
     if args.no_python:
         raise ValueError("--no_python cannot be used with TPU launcher")
@@ -898,10 +897,6 @@ def tpu_launcher(args):
             f"Your training script should have a function named {args.main_training_function}, or you should pass a "
             "different value to `--main_training_function`."
         )
-    if args.num_processes and args.num_processes != device_count():
-        raise ValueError(
-            f"Number of processes ({args.num_processes}) must match the number of TPU devices ({device_count()})"
-        )
Comment on lines -901 to -904
If you are removing that, maybe we should note somewhere that the script runs on all available TPU cores. -> print("Launching a training on all TPU cores.")

IIRC, XLA now expects no arguments at all, meaning it will use all TPU cores. If that is correct, the change seems correct.
 
     # Patch sys.argv
     sys.argv = [mod.__file__] + args.training_script_args
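For context, a minimal sketch of the spawning step after this change. The train_fn worker and the __main__ guard are illustrative stand-ins, not part of the real launcher; the point is that the parent process never calls device_count(), and xmp.spawn with no nprocs argument uses every available TPU core.

# Hypothetical sketch, not the actual accelerate launcher code.
import torch_xla.distributed.xla_multiprocessing as xmp

def train_fn(index):
    # Each spawned worker receives its process index from xmp.spawn.
    print(f"worker {index} started")

if __name__ == "__main__":
    # No nprocs argument: torch_xla discovers and uses all available TPU cores,
    # so the runtime is only initialized inside the spawned workers.
    xmp.spawn(train_fn, args=())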
Like that? @SunMarc
In the notebook_launcher function, we have the print("Launching a training on TPU cores.") that needs to be updated to print("Launching a training on all TPU cores.").
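For illustration, a rough sketch of where that message sits relative to the spawn call; notebook_launcher_tpu is a hypothetical stand-in for the TPU branch of notebook_launcher, not the real accelerate source.

# Hypothetical stand-in for the TPU branch of notebook_launcher.
import torch_xla.distributed.xla_multiprocessing as xmp

def notebook_launcher_tpu(training_function, args=()):
    print("Launching a training on all TPU cores.")
    # Let torch_xla decide the number of processes (one per available core).
    xmp.spawn(training_function, args=args)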
Here I think we can just tell users that if they pass args.num_processes, we log a warning saying that they can't choose the number of devices and that all devices are used by default. WDYT?
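A minimal sketch of that suggestion, assuming a standard logging warning; the helper name is hypothetical and not part of this PR.

# Hypothetical helper illustrating the suggested behavior.
import logging

logger = logging.getLogger(__name__)

def warn_on_num_processes(num_processes):
    # The TPU launcher cannot choose the number of devices; all available
    # TPU cores are used, so a user-supplied value only warrants a warning.
    if num_processes:
        logger.warning(
            "`--num_processes` is ignored with the TPU launcher: the number of "
            "devices cannot be chosen and all available TPU cores will be used."
        )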