Conversation

@jack8558 (Collaborator)

No description provided.

```python
# Assign to a variable to prevent garbage collection before sync.
logits = model_tpu(input_ids).logits

torch_xla.sync()  # Wait for the computation to complete.
```
Collaborator:

This doesn't actually wait for the computation to complete. It just launches the kernel on the TPU and proceeds. I think the right API here is `wait_device_ops`.

Collaborator (Author):

I see, thank you for the info. This means that even in eager mode we need to call `wait_device_ops` each time we want to measure time, correct? (So we wait until the computation actually completes.)

Collaborator (Author):

I noticed that if I use `wait_device_ops` for the preheat timing, it comes out to 35 ms, whereas `torch_xla.sync()` gives me 3000 ms. It seems that `wait_device_ops` alone does not capture the compilation time of the initial run. Since we want the first-run comparison to include compilation time, I will use `torch_xla.sync()` for the preheat timing.
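The disagreement above comes down to launch-vs-barrier semantics: one call enqueues work on the device and returns immediately, while a separate call blocks until the device is idle, and only timing across the barrier measures the real cost. Below is a minimal, hedged sketch of that timing pattern. It deliberately uses stand-in `launch`/`barrier` callables (a fake "device" backed by `time.sleep`) so it runs without a TPU; on real hardware these would be replaced by the runtime's own enqueue and drain calls, such as `torch_xla.sync()` followed by `wait_device_ops` as suggested in this thread.

```python
import time

def time_with_barrier(launch, barrier):
    """Time a device computation end to end.

    `launch` enqueues the work and may return immediately (like a
    graph launch); `barrier` blocks until the device has drained all
    pending work. Both names are placeholders for whatever the
    runtime actually provides.
    """
    start = time.perf_counter()
    launch()
    barrier()  # without this, we would only time the enqueue
    return time.perf_counter() - start

# Stand-in "device" so this sketch runs anywhere: launching merely
# records pending work; the barrier performs it (sleeps ~50 ms).
pending = []

def fake_launch():
    pending.append(0.05)

def fake_barrier():
    while pending:
        time.sleep(pending.pop())

elapsed = time_with_barrier(fake_launch, fake_barrier)
print(f"elapsed: {elapsed:.3f}s")  # roughly 0.05s once the barrier drains
```

Timing only `fake_launch()` here would report near-zero, mirroring the pitfall the reviewer points out: a launch that returns immediately makes the kernel look free unless a barrier is included in the timed region.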

@jack8558 jack8558 requested a review from bhavya01 October 15, 2025 17:20