Re-apply Fixed issue with dot_product_attention when using TPU. #21254 after addressing cuDNN/FlashAttention API updates #21333
Conversation
Corrected indentation in doc string
Fixed issue with passing a single image without batch dimension.
…scale.py Co-authored-by: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com>
Test case for unbatched inputs
Test case for checking both unbatched and batched single-image inputs.
There was a bug that was causing a cycle in the graph.
Removed the use of tree.map_structure.
…s-team#21254)" (keras-team#21329) This reverts commit 81821e0.
Enhanced the _can_use_flash_attention function to provide more detailed error messages when flash attention compatibility checks fail.
Changes:
- Replace generic exception catching with specific error propagation
- When raise_error=True, directly re-raise original exceptions from check_layout() and check_is_flash_attention() functions
- Preserve detailed error context from JAX internal validation functions
- Maintain existing behavior when raise_error=False (returns False)
This improves the debugging experience by surfacing specific technical details about tensor layout incompatibilities, cuDNN version requirements, and other flash attention compatibility issues. Relates to keras-hub PR keras-team#2257 and addresses flash attention debugging needs.
… debugging" This reverts commit 7a0c547.
…sh_attention`
Changes:
- Add the missing q_offsets=None and kv_offsets=None parameters to the check_layout() call to match the updated JAX function signature
- Replace bare `except:` with `except Exception as e:` and `raise e` to preserve detailed error messages from JAX validation functions
- Maintain existing fallback behavior when raise_error=False
This resolves compatibility issues with newer JAX versions and improves the debugging experience by surfacing specific technical details about flash attention compatibility failures.
Simplified the check for `flash_attention` by removing redundant checks that are already done in `_can_use_flash_attention`.
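The two commits above describe the same pattern: catch the exception by name and re-raise it when `raise_error=True` so the detailed JAX message survives. A minimal sketch of that control flow, using a hypothetical placeholder check in place of JAX's internal validators (`check_layout` / `check_is_flash_attention`):

```python
def _placeholder_layout_check(query_shape, key_shape):
    # Stand-in for JAX's internal flash-attention validators; the real checks
    # and their messages live inside JAX, this function is only illustrative.
    if query_shape[-1] != key_shape[-1]:
        raise ValueError(
            f"query and key head dims differ: {query_shape[-1]} vs {key_shape[-1]}"
        )


def can_use_flash_attention(query_shape, key_shape, raise_error=False):
    try:
        _placeholder_layout_check(query_shape, key_shape)
    except Exception as e:
        if raise_error:
            # Named catch + re-raise keeps the original, detailed error message,
            # instead of swallowing it behind a bare `except:`.
            raise e
        return False
    return True


print(can_use_flash_attention((1, 8, 4, 128), (1, 8, 4, 128)))  # True
print(can_use_flash_attention((1, 8, 4, 128), (1, 8, 4, 64)))   # False
```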
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #21333 +/- ##
==========================================
- Coverage 82.72% 82.67% -0.06%
==========================================
Files 565 565
Lines 54904 54941 +37
Branches 8520 8529 +9
==========================================
+ Hits 45418 45421 +3
- Misses 7399 7433 +34
Partials 2087 2087
Flags with carried forward coverage won't be shown.
LGTM, thank you for the fix.
After further debugging, I can verify that this is not a bug in JAX either. The A100 is Ampere series, compute capability 8.0 (see https://developer.nvidia.com/cuda-gpus#:~:text=8.0,A100%0ANVIDIA%20A30).
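A quick way to confirm the device family mentioned above is to inspect the local JAX devices; this assumes a CUDA-enabled jaxlib and that your version exposes `compute_capability` on GPU device objects:

```python
import jax

for d in jax.local_devices():
    if d.platform == "gpu":
        # On an A100 this is expected to print something like:
        #   NVIDIA A100-SXM4-40GB 8.0
        print(d.device_kind, getattr(d, "compute_capability", "n/a"))
```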
I've tested it; below is my experimentation setup:
- Adding a print statement in
- Adding debug statements in
- And adding debug statements in the JAX code
For Llama3.2 I'm getting the below output.
This verifies that flash_attention is working.
Whereas for Gemma3 I get the below output.
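A hedged repro sketch of the difference described above, assuming the JAX backend on a cuDNN-capable GPU and that your Keras version exposes the `flash_attention` argument on `keras.ops.dot_product_attention` (head dims are illustrative: 128 for a Llama-3.2-like config, 256 for a Gemma-like config):

```python
import numpy as np
import keras


def try_flash(head_dim):
    # (batch, seq_len, num_heads, head_dim), float16 as typically required by cuDNN.
    q = np.random.normal(size=(1, 8, 4, head_dim)).astype("float16")
    k = np.random.normal(size=(1, 8, 4, head_dim)).astype("float16")
    v = np.random.normal(size=(1, 8, 4, head_dim)).astype("float16")
    try:
        keras.ops.dot_product_attention(q, k, v, flash_attention=True)
        print(f"head_dim={head_dim}: flash attention ran")
    except Exception as e:
        print(f"head_dim={head_dim}: rejected ({e})")


try_flash(128)  # Llama-3.2-style head dim, expected to work
try_flash(256)  # Gemma-style head dim, expected to be rejected by JAX/cuDNN
```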
@divyashreepathihalli Please go through the above comments.
Pull Request Overview
Reapplies the original TPU fix for `dot_product_attention`, updates to match the latest JAX API signature, and improves flash-attention error handling and documentation.
- Updates `_can_use_flash_attention` to pass the new `q_offsets`/`kv_offsets` to `check_layout`.
- Adds named exception catches to preserve error context and improves diagnostic messages.
- Reintroduces and extends the TPU‐optimized flash-attention path with sharding support and explanatory docstrings.
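Since the breakage came from a JAX signature change, one way to see what your installed JAX version expects is to introspect the validators directly; note that `jax._src.cudnn.fused_attention_stablehlo` is a JAX-internal module, so the import path and signatures may change between releases:

```python
import inspect

from jax._src.cudnn.fused_attention_stablehlo import (  # JAX-internal, may move
    check_is_flash_attention,
    check_layout,
)

# Shows whether your JAX build expects q_offsets / kv_offsets on check_layout.
print(inspect.signature(check_layout))
print(inspect.signature(check_is_flash_attention))
```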
Comments suppressed due to low confidence (1)
keras/src/backend/jax/nn.py:1262
- The `math` module isn't imported in this file, so `math.sqrt` will raise a `NameError`. Add `import math` at the top of the module.

query_tpu_layout = query_tpu_layout * (scale * math.sqrt(head_dim))
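A minimal sketch of the suggested fix, with an illustrative head dim; the point is simply that `math` must be imported at module level before `math.sqrt` is used in the scaling expression:

```python
import math  # suggested addition at the top of keras/src/backend/jax/nn.py

head_dim = 128                      # illustrative value
scale = 1.0 / math.sqrt(head_dim)   # typical attention scaling (assumption)
print(scale * math.sqrt(head_dim))  # 1.0; math.sqrt now resolves without NameError
```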
Hi Rahul!
Hi @divyashreepathihalli,
Test Results:
This PR reapplies the changes from #21254 (“Fixed issue with dot_product_attention when using TPU”), which was previously reverted in #21329 due to a test failure involving the gemma2 Flash Attention test on A100 GPUs.
Root Cause Analysis:
- `_can_use_flash_attention` was still returning `False` on A100 GPUs.
- `check_layout` was failing because of missing parameters, due to a signature change in the JAX source code.
- After updating `check_layout` to match the new JAX API, `_can_use_flash_attention` then failed in `check_is_flash_attention` inside JAX with the error: `The head dim must be <= 128 and a multiple of 8, but got 256.`
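The constraint in that last error message can be checked ahead of time; a small sketch that mirrors the reported rule (head dim at most 128 and a multiple of 8), which explains why gemma2/gemma3 head dims of 256 are rejected while Llama-3.2-style head dims pass:

```python
def flash_attention_head_dim_supported(head_dim):
    # Mirrors the rule quoted in the JAX error above:
    # "The head dim must be <= 128 and a multiple of 8".
    return head_dim <= 128 and head_dim % 8 == 0


print(flash_attention_head_dim_supported(128))  # True  (e.g. Llama-3.2-style heads)
print(flash_attention_head_dim_supported(256))  # False (gemma2/gemma3 heads)
```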
Conclusion:
The gemma2 and gemma3 models cannot use Flash Attention on GPU until JAX adds support for larger head dimensions. The original PR did not cause a regression; the limitation is due to upstream JAX/cuDNN constraints, not this code.
Changes in this PR:
_can_use_flash_attention
to use the correct signature forcheck_layout
, matching the latest JAX API.dot_product_attention
on TPU.Note:
This PR does not re-enable Flash Attention for gemma2/gemma3 on A100 GPUs—they remain unsupported by JAX at this time. See the upstream JAX issue for future support.