diff --git a/src/accelerate/commands/launch.py b/src/accelerate/commands/launch.py
index f82845bd64a..f93d38c8a57 100644
--- a/src/accelerate/commands/launch.py
+++ b/src/accelerate/commands/launch.py
@@ -877,7 +877,6 @@ def deepspeed_launcher(args):
 
 def tpu_launcher(args):
     import torch_xla.distributed.xla_multiprocessing as xmp
-    from torch_xla import device_count
 
     if args.no_python:
         raise ValueError("--no_python cannot be used with TPU launcher")
@@ -898,10 +897,6 @@ def tpu_launcher(args):
             f"Your training script should have a function named {args.main_training_function}, or you should pass a "
             "different value to `--main_training_function`."
         )
-    if args.num_processes and args.num_processes != device_count():
-        raise ValueError(
-            f"Number of processes ({args.num_processes}) must match the number of TPU devices ({device_count()})"
-        )
 
     # Patch sys.argv
     sys.argv = [mod.__file__] + args.training_script_args
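
For context, a minimal sketch of the behavior the removal appears to lean on, assuming torch_xla's `xmp.spawn` entry point: when `nprocs` is left as `None`, torch_xla itself spawns one process per visible TPU device, so a caller does not need to cross-check a user-supplied process count against `device_count()`. The `_mp_fn` name and the print statement are illustrative only, not part of the patch.

```python
# Sketch (not part of the patch): xmp.spawn sizing the process group itself,
# which is the assumption behind dropping the device_count() validation above.
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
    # `index` is the local ordinal torch_xla assigns to each spawned process.
    print(f"TPU worker {index} started")


if __name__ == "__main__":
    # With the default nprocs=None, torch_xla launches one process per
    # available TPU device; no explicit device-count check is required.
    xmp.spawn(_mp_fn, args=())
```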