From 04cd682ac70adc8439162809baeeec244f816fba Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Tue, 6 May 2025 10:25:48 +0530 Subject: [PATCH 01/20] Update nn.py --- keras/src/backend/jax/nn.py | 154 +++++++++++++++++++++++++++--------- 1 file changed, 116 insertions(+), 38 deletions(-) diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py index ba3dbd103acb..9cb4a103ba65 100644 --- a/keras/src/backend/jax/nn.py +++ b/keras/src/backend/jax/nn.py @@ -1126,6 +1126,9 @@ def wrap_flash_attention( decoder_segment_ids, custom_mask=None, attn_logits_soft_cap=None, + head_shards=1, + q_seq_shards=1, + ): if decoder_segment_ids is not None: assert query.shape[2] == decoder_segment_ids.q.shape[1], ( @@ -1135,7 +1138,6 @@ def wrap_flash_attention( if custom_mask is not None: mask = splash_attention_mask.NumpyMask(array=custom_mask) - else: mask = splash_attention_mask.CausalMask( shape=(query.shape[2], query.shape[2]) @@ -1177,47 +1179,120 @@ def dot_product_attention( f"Received: query.shape={query.shape}, key.shape={key.shape}, " f"value.shape={value.shape}." ) + + # Check platform + platform = jax.devices()[0].platform + is_tpu = platform == "tpu" + + # Check if inputs use partial sharding (not fully replicated) + # Flash attention works well with fully replicated tensors on all platforms + # but may have issues with certain partial sharding patterns on non-TPU platforms + partially_sharded_inputs = any( + hasattr(t, "sharding") and not t.sharding.is_fully_replicated + for t in (query, key, value) + ) + + # Determine flash attention compatibility if flash_attention is None: - flash_attention = _can_use_flash_attention(query, key, value, bias) - elif flash_attention is True: - # Use `raise_error=True` to provide more details if the inputs failed to - # use flash attention - _can_use_flash_attention(query, key, value, bias, raise_error=True) - - if jax.devices()[0].platform == "tpu": - # Transpose to ('batch', 'heads', 'length', 'kv') - query = jnp.transpose(query, axes=(0, 2, 1, 3)) - key = jnp.transpose(key, axes=(0, 2, 1, 3)) - value = jnp.transpose(value, axes=(0, 2, 1, 3)) - B, H, S, KV = query.shape - - segment_ids = jnp.ones([B, S]) - # {token_ids, padding_mask, segment_ids} enable packing - out = wrap_flash_attention( - query, - key, - value, - decoder_segment_ids=splash_attention_kernel.SegmentIds( - segment_ids, segment_ids - ), - custom_mask=mask, - attn_logits_soft_cap=attn_logits_soft_cap, + # Auto-detect flash attention availability + if is_tpu: + # TPUs have specialized hardware for attention that works with any sharding pattern + flash_attention = True + else: + # For GPU/CPU with partially sharded inputs, we need multiple devices + # to efficiently handle the sharding + if partially_sharded_inputs and len(jax.devices()) <= 1: + flash_attention = False + else: + flash_attention = _can_use_flash_attention(query, key, value, bias) + elif flash_attention is True and not is_tpu: + # If flash attention is explicitly requested, validate compatibility + # Skip validation for TPU as it has specialized hardware support + try: + _can_use_flash_attention(query, key, value, bias, raise_error=True) + except Exception: + # Only disable flash attention on non-TPU platforms if validation fails + flash_attention = False + + # TPU-specific flash attention path + if is_tpu and flash_attention: + # Transpose to ('batch', 'heads', 'length', 'head_dim') + query_tpu_layout = jnp.transpose(query, axes=(0, 2, 1, 3)) + key_tpu_layout = jnp.transpose(key, axes=(0, 2, 1, 3)) + value_tpu_layout = jnp.transpose(value, axes=(0, 2, 1, 3)) + + bs, num_heads, q_len, head_dim = query_tpu_layout.shape + + # Apply scale to query if provided + if scale is not None: + # TPU kernel applies 1/sqrt(head_dim) internally, to achieve + # overall QK^T * scale, scale query by (scale * sqrt(head_dim)) + query_tpu_layout = query_tpu_layout * (scale * math.sqrt(head_dim)) + + # Create segment IDs for Splash Attention (for packing/batching) + segment_ids = jnp.zeros([bs, q_len], dtype=jnp.int32) + decoder_segment_ids = splash_attention_kernel.SegmentIds( + q=segment_ids, kv=segment_ids ) - out = jnp.transpose(out, axes=(0, 2, 1, 3)) - return out - # `dot_product_attention` is only available in jax>=0.4.31 + # Process mask for Splash Attention + custom_mask = None + if mask is not None: + mask_bool = mask.astype("bool") if mask.dtype != jnp.bool_ else mask + + if mask_bool.ndim == 3 and mask_bool.shape[0] == bs: + custom_mask = mask_bool[0] + elif mask_bool.ndim == 4 and mask_bool.shape[0] == bs: + custom_mask = mask_bool[0, 0] + + if is_causal and custom_mask is not None: + causal_mask = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_)) + custom_mask = jnp.logical_and(custom_mask, causal_mask) + + if custom_mask is None and is_causal: + custom_mask = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_)) + + try: + output = wrap_flash_attention( + query_tpu_layout, + key_tpu_layout, + value_tpu_layout, + decoder_segment_ids=decoder_segment_ids, + custom_mask=custom_mask, + attn_logits_soft_cap=attn_logits_soft_cap, + ) + # Transpose output back to Keras layout + return jnp.transpose(output, axes=(0, 2, 1, 3)) + except Exception: + flash_attention = False + + # JAX native dot_product_attention for GPU or fallback for TPU if hasattr(jax.nn, "dot_product_attention"): - return jax.nn.dot_product_attention( - query, - key, - value, - bias=bias, - mask=mask, - scale=scale, - is_causal=is_causal, - implementation="cudnn" if flash_attention else "xla", - ) + try: + return jax.nn.dot_product_attention( + query, + key, + value, + bias=bias, + mask=mask, + scale=scale, + is_causal=is_causal, + implementation="cudnn" if flash_attention else "xla", + ) + except Exception: + # If flash attention fails, fall back to XLA implementation + if flash_attention: + return jax.nn.dot_product_attention( + query, + key, + value, + bias=bias, + mask=mask, + scale=scale, + is_causal=is_causal, + implementation="xla", + ) + raise if flash_attention: raise RuntimeError( @@ -1228,6 +1303,9 @@ def dot_product_attention( # Ref: jax.nn.dot_product_attention # https://github.com/jax-ml/jax/blob/jax-v0.4.33/jax/_src/nn/functions.py#L886 # Not support `query_seq_lengths` and `key_value_seq_lengths` args + + # Fallback to custom XLA implementation + # This is the reference implementation from jax.nn.dot_product_attention output_shape = query.shape _, _, K, H = key.shape scale = (1.0 / jnp.sqrt(H)) if scale is None else scale From 1a7446523d1889f2515b3ab39a64c3b293ffe195 Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Tue, 6 May 2025 10:47:14 +0530 Subject: [PATCH 02/20] Update nn.py --- keras/src/backend/jax/nn.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py index 9cb4a103ba65..96cf87003b0c 100644 --- a/keras/src/backend/jax/nn.py +++ b/keras/src/backend/jax/nn.py @@ -1128,7 +1128,6 @@ def wrap_flash_attention( attn_logits_soft_cap=None, head_shards=1, q_seq_shards=1, - ): if decoder_segment_ids is not None: assert query.shape[2] == decoder_segment_ids.q.shape[1], ( @@ -1149,8 +1148,8 @@ def wrap_flash_attention( ) splash_kernel = splash_attention_kernel.make_splash_mha( mask=multi_head_mask, - head_shards=1, - q_seq_shards=1, + head_shards=head_shards, + q_seq_shards=q_seq_shards, attn_logits_soft_cap=attn_logits_soft_cap, ) @@ -1169,6 +1168,8 @@ def dot_product_attention( is_causal=False, flash_attention=None, attn_logits_soft_cap=None, + head_shards=1, + q_seq_shards=1, ): query = convert_to_tensor(query) key = convert_to_tensor(key) @@ -1260,6 +1261,8 @@ def dot_product_attention( decoder_segment_ids=decoder_segment_ids, custom_mask=custom_mask, attn_logits_soft_cap=attn_logits_soft_cap, + head_shards=head_shards, # Pass the parameter value instead of hardcoding to 1 + q_seq_shards=q_seq_shards, # Pass the parameter value instead of hardcoding to 1 ) # Transpose output back to Keras layout return jnp.transpose(output, axes=(0, 2, 1, 3)) From c11eb819acd9601db5333a7e87eb45420e5c7e24 Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Tue, 6 May 2025 11:16:29 +0530 Subject: [PATCH 03/20] Update nn.py --- keras/src/backend/jax/nn.py | 57 ++++++++++++++++++++++++++++++++++--- 1 file changed, 53 insertions(+), 4 deletions(-) diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py index 96cf87003b0c..e21d0169f3d6 100644 --- a/keras/src/backend/jax/nn.py +++ b/keras/src/backend/jax/nn.py @@ -1168,9 +1168,34 @@ def dot_product_attention( is_causal=False, flash_attention=None, attn_logits_soft_cap=None, - head_shards=1, - q_seq_shards=1, ): + """Computes dot-product attention given query, key, and value. + + This is the core computation of attention that is used in transformers. + For TPU platforms, flash attention optimizations are automatically applied + when possible, and sharding parameters are inferred from the layout map + in the current distribution context. + + Args: + query: JAX Array or KerasTensor. Queries with shape `[batch, time, heads, depth_k]`. + key: JAX Array or KerasTensor. Keys with shape `[batch, time, heads, depth_k]`. + value: JAX Array or KerasTensor. Values with shape `[batch, time, heads, depth_v]`. + bias: JAX Array or KerasTensor. Optional bias with shape broadcastable to + `[batch, heads, dest_time, source_time]`. + mask: JAX Array or KerasTensor. Optional mask with shape broadcastable to + `[batch, heads, dest_time, source_time]`. + scale: Float. Optional scale that is applied to the attention computation. + is_causal: Boolean. Specifying whether causal masking is applied. + flash_attention: Boolean. Whether to use flash attention optimization for + increased performance. Default to None, which means it will be + auto-determined based on the platform, input shapes and compatibility. + attn_logits_soft_cap: Float. Optional float to softly cap attention logits to + avoid numerical stability issues. Applied as: + `logits = logits / (1.0 + abs(logits) / attn_logits_soft_cap)`. + + Returns: + JAX Array of shape `[batch, time, heads, depth_v]`. + """ query = convert_to_tensor(query) key = convert_to_tensor(key) value = convert_to_tensor(value) @@ -1185,6 +1210,30 @@ def dot_product_attention( platform = jax.devices()[0].platform is_tpu = platform == "tpu" + # Get sharding parameters from distribution context + head_shards = 1 + q_seq_shards = 1 + + if is_tpu: + try: + from keras.src.distribution.distribution_lib import distribution as get_dist + from keras.src.distribution.distribution_lib import ModelParallel + + # Get current distribution if available + dist = get_dist() + if dist and isinstance(dist, ModelParallel): + mesh = dist.device_mesh + if "model" in mesh.axis_names: + model_dim_index = mesh.axis_names.index("model") + # Set head_shards based on the model dimension of the mesh + head_shards = mesh.shape[model_dim_index] + # Typically keep q_seq_shards=1 for best performance + q_seq_shards = 1 + except (ImportError, ValueError, AttributeError): + # Use default values if detection fails + head_shards = 1 + q_seq_shards = 1 + # Check if inputs use partial sharding (not fully replicated) # Flash attention works well with fully replicated tensors on all platforms # but may have issues with certain partial sharding patterns on non-TPU platforms @@ -1261,8 +1310,8 @@ def dot_product_attention( decoder_segment_ids=decoder_segment_ids, custom_mask=custom_mask, attn_logits_soft_cap=attn_logits_soft_cap, - head_shards=head_shards, # Pass the parameter value instead of hardcoding to 1 - q_seq_shards=q_seq_shards, # Pass the parameter value instead of hardcoding to 1 + head_shards=head_shards, + q_seq_shards=q_seq_shards, ) # Transpose output back to Keras layout return jnp.transpose(output, axes=(0, 2, 1, 3)) From c81e18c8e3bfdfd5c7288a242a8376cdb383a2b8 Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Tue, 6 May 2025 12:09:55 +0530 Subject: [PATCH 04/20] Update nn.py --- keras/src/backend/jax/nn.py | 95 +++++++++++++++++++++---------------- 1 file changed, 54 insertions(+), 41 deletions(-) diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py index e21d0169f3d6..6cb9dedc8d1c 100644 --- a/keras/src/backend/jax/nn.py +++ b/keras/src/backend/jax/nn.py @@ -205,9 +205,9 @@ def _pool( initial_value: the initial value for the reduction. reduce_fn: a reduce function of the form `(T, T) -> T`. pool_size: a sequence of `N` integers, representing the window size to - reduce over. + reduce over. strides: a sequence of `N` integers, representing the inter-window - strides (default: `(1, ..., 1)`). + strides (default: `(1, ..., 1)`). padding: either the string `same` or `valid`. Returns: @@ -1131,8 +1131,8 @@ def wrap_flash_attention( ): if decoder_segment_ids is not None: assert query.shape[2] == decoder_segment_ids.q.shape[1], ( - "Sharding along sequence dimension not allowed in tpu kernel " - "attention" + "Sharding along sequence dimension not allowed" + " in tpu kernel attention" ) if custom_mask is not None: @@ -1148,8 +1148,8 @@ def wrap_flash_attention( ) splash_kernel = splash_attention_kernel.make_splash_mha( mask=multi_head_mask, - head_shards=head_shards, - q_seq_shards=q_seq_shards, + head_shards=head_shards, + q_seq_shards=q_seq_shards, attn_logits_soft_cap=attn_logits_soft_cap, ) @@ -1170,28 +1170,32 @@ def dot_product_attention( attn_logits_soft_cap=None, ): """Computes dot-product attention given query, key, and value. - + This is the core computation of attention that is used in transformers. For TPU platforms, flash attention optimizations are automatically applied when possible, and sharding parameters are inferred from the layout map in the current distribution context. - + Args: - query: JAX Array or KerasTensor. Queries with shape `[batch, time, heads, depth_k]`. - key: JAX Array or KerasTensor. Keys with shape `[batch, time, heads, depth_k]`. - value: JAX Array or KerasTensor. Values with shape `[batch, time, heads, depth_v]`. - bias: JAX Array or KerasTensor. Optional bias with shape broadcastable to - `[batch, heads, dest_time, source_time]`. - mask: JAX Array or KerasTensor. Optional mask with shape broadcastable to - `[batch, heads, dest_time, source_time]`. - scale: Float. Optional scale that is applied to the attention computation. + query: Queries with shape `[batch, time, heads, + depth_k]`. + key: Keys with shape `[batch, time, heads, + depth_k]`. + value: Values with shape `[batch, time, heads, + depth_v]`. + bias: Optional bias with shape broadcastable to + `[batch, heads, dest_time, source_time]`. + mask: Optional mask with shape broadcastable to + `[batch, heads, dest_time, source_time]`. + scale: Float. Optional scale that is applied to the attention + computation. is_causal: Boolean. Specifying whether causal masking is applied. - flash_attention: Boolean. Whether to use flash attention optimization for - increased performance. Default to None, which means it will be - auto-determined based on the platform, input shapes and compatibility. - attn_logits_soft_cap: Float. Optional float to softly cap attention logits to - avoid numerical stability issues. Applied as: - `logits = logits / (1.0 + abs(logits) / attn_logits_soft_cap)`. + flash_attention: Boolean. Whether to use flash attention optimization + for increased performance. Default to None, which means it will be + auto-determined based on the platform, input shapes and compatibility. + attn_logits_soft_cap: Float. Optional float to softly cap attention + logits to avoid numerical stability issues. Applied as: + `logits = logits / (1.0 + abs(logits) / attn_logits_soft_cap)`. Returns: JAX Array of shape `[batch, time, heads, depth_v]`. @@ -1205,20 +1209,22 @@ def dot_product_attention( f"Received: query.shape={query.shape}, key.shape={key.shape}, " f"value.shape={value.shape}." ) - + # Check platform platform = jax.devices()[0].platform is_tpu = platform == "tpu" - + # Get sharding parameters from distribution context head_shards = 1 q_seq_shards = 1 - + if is_tpu: try: - from keras.src.distribution.distribution_lib import distribution as get_dist from keras.src.distribution.distribution_lib import ModelParallel - + from keras.src.distribution.distribution_lib import ( + distribution as get_dist, + ) + # Get current distribution if available dist = get_dist() if dist and isinstance(dist, ModelParallel): @@ -1233,37 +1239,42 @@ def dot_product_attention( # Use default values if detection fails head_shards = 1 q_seq_shards = 1 - + # Check if inputs use partial sharding (not fully replicated) # Flash attention works well with fully replicated tensors on all platforms - # but may have issues with certain partial sharding patterns on non-TPU platforms + # but may have issues with certain partial sharding patterns on non-TPU + # platforms partially_sharded_inputs = any( hasattr(t, "sharding") and not t.sharding.is_fully_replicated for t in (query, key, value) ) - + # Determine flash attention compatibility if flash_attention is None: # Auto-detect flash attention availability if is_tpu: - # TPUs have specialized hardware for attention that works with any sharding pattern + # TPUs have specialized hardware for attention that works with any + # sharding pattern flash_attention = True else: - # For GPU/CPU with partially sharded inputs, we need multiple devices - # to efficiently handle the sharding + # For GPU/CPU with partially sharded inputs, we need + # multiple devices to efficiently handle the sharding if partially_sharded_inputs and len(jax.devices()) <= 1: flash_attention = False else: - flash_attention = _can_use_flash_attention(query, key, value, bias) + flash_attention = _can_use_flash_attention( + query, key, value, bias + ) elif flash_attention is True and not is_tpu: # If flash attention is explicitly requested, validate compatibility # Skip validation for TPU as it has specialized hardware support try: _can_use_flash_attention(query, key, value, bias, raise_error=True) except Exception: - # Only disable flash attention on non-TPU platforms if validation fails + # Only disable flash attention on non-TPU platforms + # if validation fails flash_attention = False - + # TPU-specific flash attention path if is_tpu and flash_attention: # Transpose to ('batch', 'heads', 'length', 'head_dim') @@ -1275,7 +1286,7 @@ def dot_product_attention( # Apply scale to query if provided if scale is not None: - # TPU kernel applies 1/sqrt(head_dim) internally, to achieve + # TPU kernel applies 1/sqrt(head_dim) internally, to achieve # overall QK^T * scale, scale query by (scale * sqrt(head_dim)) query_tpu_layout = query_tpu_layout * (scale * math.sqrt(head_dim)) @@ -1289,16 +1300,18 @@ def dot_product_attention( custom_mask = None if mask is not None: mask_bool = mask.astype("bool") if mask.dtype != jnp.bool_ else mask - + if mask_bool.ndim == 3 and mask_bool.shape[0] == bs: custom_mask = mask_bool[0] elif mask_bool.ndim == 4 and mask_bool.shape[0] == bs: custom_mask = mask_bool[0, 0] if is_causal and custom_mask is not None: - causal_mask = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_)) + causal_mask = jnp.tril( + jnp.ones((q_len, q_len), dtype=jnp.bool_) + ) custom_mask = jnp.logical_and(custom_mask, causal_mask) - + if custom_mask is None and is_causal: custom_mask = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_)) @@ -1355,7 +1368,7 @@ def dot_product_attention( # Ref: jax.nn.dot_product_attention # https://github.com/jax-ml/jax/blob/jax-v0.4.33/jax/_src/nn/functions.py#L886 # Not support `query_seq_lengths` and `key_value_seq_lengths` args - + # Fallback to custom XLA implementation # This is the reference implementation from jax.nn.dot_product_attention output_shape = query.shape From d938e20c524b0df619986e43683030cb48ccfe6d Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Wed, 7 May 2025 10:27:30 +0530 Subject: [PATCH 05/20] Update nn.py Corrected indentation in doc string --- keras/src/backend/jax/nn.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py index 6cb9dedc8d1c..258bd0af7e04 100644 --- a/keras/src/backend/jax/nn.py +++ b/keras/src/backend/jax/nn.py @@ -205,9 +205,9 @@ def _pool( initial_value: the initial value for the reduction. reduce_fn: a reduce function of the form `(T, T) -> T`. pool_size: a sequence of `N` integers, representing the window size to - reduce over. + reduce over. strides: a sequence of `N` integers, representing the inter-window - strides (default: `(1, ..., 1)`). + strides (default: `(1, ..., 1)`). padding: either the string `same` or `valid`. Returns: @@ -1132,7 +1132,7 @@ def wrap_flash_attention( if decoder_segment_ids is not None: assert query.shape[2] == decoder_segment_ids.q.shape[1], ( "Sharding along sequence dimension not allowed" - " in tpu kernel attention" + " in TPU kernel attention" ) if custom_mask is not None: @@ -1178,24 +1178,24 @@ def dot_product_attention( Args: query: Queries with shape `[batch, time, heads, - depth_k]`. + depth_k]`. key: Keys with shape `[batch, time, heads, - depth_k]`. + depth_k]`. value: Values with shape `[batch, time, heads, - depth_v]`. + depth_v]`. bias: Optional bias with shape broadcastable to - `[batch, heads, dest_time, source_time]`. + `[batch, heads, dest_time, source_time]`. mask: Optional mask with shape broadcastable to - `[batch, heads, dest_time, source_time]`. + `[batch, heads, dest_time, source_time]`. scale: Float. Optional scale that is applied to the attention - computation. + computation. is_causal: Boolean. Specifying whether causal masking is applied. flash_attention: Boolean. Whether to use flash attention optimization for increased performance. Default to None, which means it will be - auto-determined based on the platform, input shapes and compatibility. + auto-determined based on the platform, input shapes and compatibility. attn_logits_soft_cap: Float. Optional float to softly cap attention - logits to avoid numerical stability issues. Applied as: - `logits = logits / (1.0 + abs(logits) / attn_logits_soft_cap)`. + logits to avoid numerical stability issues. Applied as: + `logits = logits / (1.0 + abs(logits) / attn_logits_soft_cap)`. Returns: JAX Array of shape `[batch, time, heads, depth_v]`. From f60811ef86c819aa7aee516fa20e4dfe44239b31 Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Wed, 7 May 2025 10:41:11 +0530 Subject: [PATCH 06/20] Update nn.py --- keras/src/backend/jax/nn.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py index 258bd0af7e04..cb2a7716c6ce 100644 --- a/keras/src/backend/jax/nn.py +++ b/keras/src/backend/jax/nn.py @@ -1191,8 +1191,9 @@ def dot_product_attention( computation. is_causal: Boolean. Specifying whether causal masking is applied. flash_attention: Boolean. Whether to use flash attention optimization - for increased performance. Default to None, which means it will be - auto-determined based on the platform, input shapes and compatibility. + for increased performance. Default to None, which means it will + be auto-determined based on the platform, input shapes and + compatibility. attn_logits_soft_cap: Float. Optional float to softly cap attention logits to avoid numerical stability issues. Applied as: `logits = logits / (1.0 + abs(logits) / attn_logits_soft_cap)`. From 28eeb2495fb9d72c3a93bf02dd2ae3a36ba26abd Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Mon, 12 May 2025 13:17:09 +0530 Subject: [PATCH 07/20] Update random_grayscale.py Fixed issue with passing a single image without batch dimension. --- .../image_preprocessing/random_grayscale.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py index 2dbcca6e5026..99ecf860eae5 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py @@ -59,12 +59,20 @@ def __init__(self, factor=0.5, data_format=None, seed=None, **kwargs): def get_random_transformation(self, images, training=True, seed=None): if seed is None: seed = self._get_seed_generator(self.backend._backend) - random_values = self.backend.random.uniform( - shape=(self.backend.core.shape(images)[0],), - minval=0, - maxval=1, - seed=seed, - ) + if len(images.shape) == 4: + random_values = self.backend.random.uniform( + shape=(self.backend.core.shape(images)[0],), + minval=0, + maxval=1, + seed=seed, + ) + else: + random_values = self.backend.random.uniform( + shape=(1,), + minval=0, + maxval=1, + seed=seed, + ) should_apply = self.backend.numpy.expand_dims( random_values < self.factor, axis=[1, 2, 3] ) From de81e5bc1da01e4a55aedc73fcf16725b6be3002 Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR <55033230+pctablet505@users.noreply.github.com> Date: Mon, 12 May 2025 14:19:12 +0530 Subject: [PATCH 08/20] Update keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py Co-authored-by: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> --- .../image_preprocessing/random_grayscale.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py index 99ecf860eae5..31c911f476fa 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py @@ -59,16 +59,14 @@ def __init__(self, factor=0.5, data_format=None, seed=None, **kwargs): def get_random_transformation(self, images, training=True, seed=None): if seed is None: seed = self._get_seed_generator(self.backend._backend) + # Base case: Unbatched data + batch_size = 1 if len(images.shape) == 4: - random_values = self.backend.random.uniform( - shape=(self.backend.core.shape(images)[0],), - minval=0, - maxval=1, - seed=seed, - ) - else: - random_values = self.backend.random.uniform( - shape=(1,), + # This is a batch of images (4D input) + batch_size = self.backend.core.shape(images)[0] + + random_values = self.backend.random.uniform( + shape=(batch_size,), minval=0, maxval=1, seed=seed, From 66661ac827c0c43c46958ebe34a7ff0216e56216 Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Mon, 12 May 2025 16:05:18 +0530 Subject: [PATCH 09/20] Update random_grayscale_test.py Test case for unbatched inputs --- .../random_grayscale_test.py | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale_test.py index b488c2c31f83..983554ef82aa 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale_test.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale_test.py @@ -80,15 +80,28 @@ def test_grayscale_with_single_color_image(self): test_cases = [ (np.full((1, 4, 4, 3), 128, dtype=np.float32), "channels_last"), (np.full((1, 3, 4, 4), 128, dtype=np.float32), "channels_first"), + # unbatched inputs + (np.full((4, 4, 3), 128, dtype=np.float32), "channels_last"), + (np.full((3, 4, 4), 128, dtype=np.float32), "channels_first"), ] for xs, data_format in test_cases: layer = layers.RandomGrayscale(factor=1.0, data_format=data_format) transformed = ops.convert_to_numpy(layer(xs)) - - if data_format == "channels_last": - unique_vals = np.unique(transformed[0, :, :, 0]) - self.assertEqual(len(unique_vals), 1) + + if len(xs.shape)==4: + # batched inputs + if data_format == "channels_last": + unique_vals = np.unique(transformed[0, :, :, 0]) + self.assertEqual(len(unique_vals), 1) + else: + unique_vals = np.unique(transformed[0, 0, :, :]) + self.assertEqual(len(unique_vals), 1) else: - unique_vals = np.unique(transformed[0, 0, :, :]) - self.assertEqual(len(unique_vals), 1) + # unbatched inputs + if data_format == "channels_last": + unique_vals = np.unique(transformed[ :, :, 0]) + self.assertEqual(len(unique_vals), 1) + else: + unique_vals = np.unique(transformed[ 0, :, :]) + self.assertEqual(len(unique_vals), 1) From c37f2b51c658fe0b5c981960ba5b629a718b1571 Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Tue, 13 May 2025 11:59:55 +0530 Subject: [PATCH 10/20] code reformat --- .../image_preprocessing/random_grayscale.py | 12 ++++++------ .../image_preprocessing/random_grayscale_test.py | 8 ++++---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py index 31c911f476fa..865c55a3ceeb 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py @@ -62,15 +62,15 @@ def get_random_transformation(self, images, training=True, seed=None): # Base case: Unbatched data batch_size = 1 if len(images.shape) == 4: - # This is a batch of images (4D input) + # This is a batch of images (4D input) batch_size = self.backend.core.shape(images)[0] random_values = self.backend.random.uniform( - shape=(batch_size,), - minval=0, - maxval=1, - seed=seed, - ) + shape=(batch_size,), + minval=0, + maxval=1, + seed=seed, + ) should_apply = self.backend.numpy.expand_dims( random_values < self.factor, axis=[1, 2, 3] ) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale_test.py index 983554ef82aa..12ba46f275f4 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale_test.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale_test.py @@ -88,8 +88,8 @@ def test_grayscale_with_single_color_image(self): for xs, data_format in test_cases: layer = layers.RandomGrayscale(factor=1.0, data_format=data_format) transformed = ops.convert_to_numpy(layer(xs)) - - if len(xs.shape)==4: + + if len(xs.shape) == 4: # batched inputs if data_format == "channels_last": unique_vals = np.unique(transformed[0, :, :, 0]) @@ -100,8 +100,8 @@ def test_grayscale_with_single_color_image(self): else: # unbatched inputs if data_format == "channels_last": - unique_vals = np.unique(transformed[ :, :, 0]) + unique_vals = np.unique(transformed[:, :, 0]) self.assertEqual(len(unique_vals), 1) else: - unique_vals = np.unique(transformed[ 0, :, :]) + unique_vals = np.unique(transformed[0, :, :]) self.assertEqual(len(unique_vals), 1) From 498dece497053967fa09209f8ff9c3b052bb66b7 Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Tue, 13 May 2025 13:09:53 +0530 Subject: [PATCH 11/20] Update random_grayscale_test.py Testcase for checking both unbatched and batched single image inputs. --- .../random_grayscale_test.py | 37 +++++++++++-------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale_test.py index 12ba46f275f4..a43dfc55694a 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale_test.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale_test.py @@ -78,6 +78,7 @@ def test_tf_data_compatibility(self): def test_grayscale_with_single_color_image(self): test_cases = [ + # batched inputs (np.full((1, 4, 4, 3), 128, dtype=np.float32), "channels_last"), (np.full((1, 3, 4, 4), 128, dtype=np.float32), "channels_first"), # unbatched inputs @@ -89,19 +90,25 @@ def test_grayscale_with_single_color_image(self): layer = layers.RandomGrayscale(factor=1.0, data_format=data_format) transformed = ops.convert_to_numpy(layer(xs)) - if len(xs.shape) == 4: - # batched inputs - if data_format == "channels_last": - unique_vals = np.unique(transformed[0, :, :, 0]) - self.assertEqual(len(unique_vals), 1) - else: - unique_vals = np.unique(transformed[0, 0, :, :]) - self.assertEqual(len(unique_vals), 1) + # Determine if the input was batched + is_batched = len(xs.shape) == 4 + + # If batched, select the first image from the batch for inspection. + # Otherwise, use the transformed image directly. + # `image_to_inspect` will always be a 3D tensor. + if is_batched: + image_to_inspect = transformed[0] else: - # unbatched inputs - if data_format == "channels_last": - unique_vals = np.unique(transformed[:, :, 0]) - self.assertEqual(len(unique_vals), 1) - else: - unique_vals = np.unique(transformed[0, :, :]) - self.assertEqual(len(unique_vals), 1) + image_to_inspect = transformed + + if data_format == "channels_last": + # image_to_inspect has shape (H, W, C), + # get the first channel [:, :, 0] + channel_data = image_to_inspect[:, :, 0] + else: # data_format == "channels_first" + # image_to_inspect has shape (C, H, W), + # get the first channel [0, :, :] + channel_data = image_to_inspect[0, :, :] + + unique_vals = np.unique(channel_data) + self.assertEqual(len(unique_vals), 1) From 653f5b11d0762fc553c6849e41fd467fc383b66b Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Wed, 21 May 2025 11:10:20 +0530 Subject: [PATCH 12/20] changed compute_output_spec There was a bug, and it was causing cycle in graph. --- .../preprocessing/image_preprocessing/random_grayscale.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py index 865c55a3ceeb..ca693a246704 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py @@ -1,4 +1,5 @@ from keras.src import backend +from keras.src import tree from keras.src.api_export import keras_export from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 BaseImagePreprocessingLayer, @@ -96,7 +97,12 @@ def compute_output_shape(self, input_shape): return input_shape def compute_output_spec(self, inputs, **kwargs): - return inputs + return tree.map_structure( + lambda x: backend.KerasTensor( + x.shape, dtype=x.dtype, sparse=x.sparse + ), + inputs, + ) def transform_bounding_boxes(self, bounding_boxes, **kwargs): return bounding_boxes From 27ad80bf5c34583dec0c72d3f429af6257d24fad Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Mon, 26 May 2025 13:57:06 +0530 Subject: [PATCH 13/20] Update random_grayscale.py removed the use of tree.map_structure --- .../preprocessing/image_preprocessing/random_grayscale.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py index ca693a246704..ca071d263de7 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py @@ -1,5 +1,4 @@ from keras.src import backend -from keras.src import tree from keras.src.api_export import keras_export from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 BaseImagePreprocessingLayer, @@ -97,11 +96,8 @@ def compute_output_shape(self, input_shape): return input_shape def compute_output_spec(self, inputs, **kwargs): - return tree.map_structure( - lambda x: backend.KerasTensor( - x.shape, dtype=x.dtype, sparse=x.sparse - ), - inputs, + return backend.KerasTensor( + inputs.shape, dtype=inputs.dtype, sparse=inputs.sparse ) def transform_bounding_boxes(self, bounding_boxes, **kwargs): From 579cc11d705a979d913c2b0840c3877d76db2533 Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Thu, 29 May 2025 15:20:03 +0530 Subject: [PATCH 14/20] Reapply "Fixed issue with dot_product_attention when using TPU. (#21254)" (#21329) This reverts commit 81821e02486886436d10bb59bdfdf1715ebcca1a. --- keras/src/backend/jax/nn.py | 228 +++++++++++++++++++++++++++++------- 1 file changed, 186 insertions(+), 42 deletions(-) diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py index ba3dbd103acb..cb2a7716c6ce 100644 --- a/keras/src/backend/jax/nn.py +++ b/keras/src/backend/jax/nn.py @@ -1126,16 +1126,17 @@ def wrap_flash_attention( decoder_segment_ids, custom_mask=None, attn_logits_soft_cap=None, + head_shards=1, + q_seq_shards=1, ): if decoder_segment_ids is not None: assert query.shape[2] == decoder_segment_ids.q.shape[1], ( - "Sharding along sequence dimension not allowed in tpu kernel " - "attention" + "Sharding along sequence dimension not allowed" + " in TPU kernel attention" ) if custom_mask is not None: mask = splash_attention_mask.NumpyMask(array=custom_mask) - else: mask = splash_attention_mask.CausalMask( shape=(query.shape[2], query.shape[2]) @@ -1147,8 +1148,8 @@ def wrap_flash_attention( ) splash_kernel = splash_attention_kernel.make_splash_mha( mask=multi_head_mask, - head_shards=1, - q_seq_shards=1, + head_shards=head_shards, + q_seq_shards=q_seq_shards, attn_logits_soft_cap=attn_logits_soft_cap, ) @@ -1168,6 +1169,38 @@ def dot_product_attention( flash_attention=None, attn_logits_soft_cap=None, ): + """Computes dot-product attention given query, key, and value. + + This is the core computation of attention that is used in transformers. + For TPU platforms, flash attention optimizations are automatically applied + when possible, and sharding parameters are inferred from the layout map + in the current distribution context. + + Args: + query: Queries with shape `[batch, time, heads, + depth_k]`. + key: Keys with shape `[batch, time, heads, + depth_k]`. + value: Values with shape `[batch, time, heads, + depth_v]`. + bias: Optional bias with shape broadcastable to + `[batch, heads, dest_time, source_time]`. + mask: Optional mask with shape broadcastable to + `[batch, heads, dest_time, source_time]`. + scale: Float. Optional scale that is applied to the attention + computation. + is_causal: Boolean. Specifying whether causal masking is applied. + flash_attention: Boolean. Whether to use flash attention optimization + for increased performance. Default to None, which means it will + be auto-determined based on the platform, input shapes and + compatibility. + attn_logits_soft_cap: Float. Optional float to softly cap attention + logits to avoid numerical stability issues. Applied as: + `logits = logits / (1.0 + abs(logits) / attn_logits_soft_cap)`. + + Returns: + JAX Array of shape `[batch, time, heads, depth_v]`. + """ query = convert_to_tensor(query) key = convert_to_tensor(key) value = convert_to_tensor(value) @@ -1177,47 +1210,155 @@ def dot_product_attention( f"Received: query.shape={query.shape}, key.shape={key.shape}, " f"value.shape={value.shape}." ) - if flash_attention is None: - flash_attention = _can_use_flash_attention(query, key, value, bias) - elif flash_attention is True: - # Use `raise_error=True` to provide more details if the inputs failed to - # use flash attention - _can_use_flash_attention(query, key, value, bias, raise_error=True) - if jax.devices()[0].platform == "tpu": - # Transpose to ('batch', 'heads', 'length', 'kv') - query = jnp.transpose(query, axes=(0, 2, 1, 3)) - key = jnp.transpose(key, axes=(0, 2, 1, 3)) - value = jnp.transpose(value, axes=(0, 2, 1, 3)) - B, H, S, KV = query.shape - - segment_ids = jnp.ones([B, S]) - # {token_ids, padding_mask, segment_ids} enable packing - out = wrap_flash_attention( - query, - key, - value, - decoder_segment_ids=splash_attention_kernel.SegmentIds( - segment_ids, segment_ids - ), - custom_mask=mask, - attn_logits_soft_cap=attn_logits_soft_cap, + # Check platform + platform = jax.devices()[0].platform + is_tpu = platform == "tpu" + + # Get sharding parameters from distribution context + head_shards = 1 + q_seq_shards = 1 + + if is_tpu: + try: + from keras.src.distribution.distribution_lib import ModelParallel + from keras.src.distribution.distribution_lib import ( + distribution as get_dist, + ) + + # Get current distribution if available + dist = get_dist() + if dist and isinstance(dist, ModelParallel): + mesh = dist.device_mesh + if "model" in mesh.axis_names: + model_dim_index = mesh.axis_names.index("model") + # Set head_shards based on the model dimension of the mesh + head_shards = mesh.shape[model_dim_index] + # Typically keep q_seq_shards=1 for best performance + q_seq_shards = 1 + except (ImportError, ValueError, AttributeError): + # Use default values if detection fails + head_shards = 1 + q_seq_shards = 1 + + # Check if inputs use partial sharding (not fully replicated) + # Flash attention works well with fully replicated tensors on all platforms + # but may have issues with certain partial sharding patterns on non-TPU + # platforms + partially_sharded_inputs = any( + hasattr(t, "sharding") and not t.sharding.is_fully_replicated + for t in (query, key, value) + ) + + # Determine flash attention compatibility + if flash_attention is None: + # Auto-detect flash attention availability + if is_tpu: + # TPUs have specialized hardware for attention that works with any + # sharding pattern + flash_attention = True + else: + # For GPU/CPU with partially sharded inputs, we need + # multiple devices to efficiently handle the sharding + if partially_sharded_inputs and len(jax.devices()) <= 1: + flash_attention = False + else: + flash_attention = _can_use_flash_attention( + query, key, value, bias + ) + elif flash_attention is True and not is_tpu: + # If flash attention is explicitly requested, validate compatibility + # Skip validation for TPU as it has specialized hardware support + try: + _can_use_flash_attention(query, key, value, bias, raise_error=True) + except Exception: + # Only disable flash attention on non-TPU platforms + # if validation fails + flash_attention = False + + # TPU-specific flash attention path + if is_tpu and flash_attention: + # Transpose to ('batch', 'heads', 'length', 'head_dim') + query_tpu_layout = jnp.transpose(query, axes=(0, 2, 1, 3)) + key_tpu_layout = jnp.transpose(key, axes=(0, 2, 1, 3)) + value_tpu_layout = jnp.transpose(value, axes=(0, 2, 1, 3)) + + bs, num_heads, q_len, head_dim = query_tpu_layout.shape + + # Apply scale to query if provided + if scale is not None: + # TPU kernel applies 1/sqrt(head_dim) internally, to achieve + # overall QK^T * scale, scale query by (scale * sqrt(head_dim)) + query_tpu_layout = query_tpu_layout * (scale * math.sqrt(head_dim)) + + # Create segment IDs for Splash Attention (for packing/batching) + segment_ids = jnp.zeros([bs, q_len], dtype=jnp.int32) + decoder_segment_ids = splash_attention_kernel.SegmentIds( + q=segment_ids, kv=segment_ids ) - out = jnp.transpose(out, axes=(0, 2, 1, 3)) - return out - # `dot_product_attention` is only available in jax>=0.4.31 + # Process mask for Splash Attention + custom_mask = None + if mask is not None: + mask_bool = mask.astype("bool") if mask.dtype != jnp.bool_ else mask + + if mask_bool.ndim == 3 and mask_bool.shape[0] == bs: + custom_mask = mask_bool[0] + elif mask_bool.ndim == 4 and mask_bool.shape[0] == bs: + custom_mask = mask_bool[0, 0] + + if is_causal and custom_mask is not None: + causal_mask = jnp.tril( + jnp.ones((q_len, q_len), dtype=jnp.bool_) + ) + custom_mask = jnp.logical_and(custom_mask, causal_mask) + + if custom_mask is None and is_causal: + custom_mask = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_)) + + try: + output = wrap_flash_attention( + query_tpu_layout, + key_tpu_layout, + value_tpu_layout, + decoder_segment_ids=decoder_segment_ids, + custom_mask=custom_mask, + attn_logits_soft_cap=attn_logits_soft_cap, + head_shards=head_shards, + q_seq_shards=q_seq_shards, + ) + # Transpose output back to Keras layout + return jnp.transpose(output, axes=(0, 2, 1, 3)) + except Exception: + flash_attention = False + + # JAX native dot_product_attention for GPU or fallback for TPU if hasattr(jax.nn, "dot_product_attention"): - return jax.nn.dot_product_attention( - query, - key, - value, - bias=bias, - mask=mask, - scale=scale, - is_causal=is_causal, - implementation="cudnn" if flash_attention else "xla", - ) + try: + return jax.nn.dot_product_attention( + query, + key, + value, + bias=bias, + mask=mask, + scale=scale, + is_causal=is_causal, + implementation="cudnn" if flash_attention else "xla", + ) + except Exception: + # If flash attention fails, fall back to XLA implementation + if flash_attention: + return jax.nn.dot_product_attention( + query, + key, + value, + bias=bias, + mask=mask, + scale=scale, + is_causal=is_causal, + implementation="xla", + ) + raise if flash_attention: raise RuntimeError( @@ -1228,6 +1369,9 @@ def dot_product_attention( # Ref: jax.nn.dot_product_attention # https://github.com/jax-ml/jax/blob/jax-v0.4.33/jax/_src/nn/functions.py#L886 # Not support `query_seq_lengths` and `key_value_seq_lengths` args + + # Fallback to custom XLA implementation + # This is the reference implementation from jax.nn.dot_product_attention output_shape = query.shape _, _, K, H = key.shape scale = (1.0 / jnp.sqrt(H)) if scale is None else scale From 7a0c5473c3091a2c90db031515c1c3f8daae8e7a Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Thu, 29 May 2025 15:35:49 +0530 Subject: [PATCH 15/20] Improve error handling in _can_use_flash_attention for better debugging Enhanced the _can_use_flash_attention function to provide more detailed error messages when flash attention compatibility checks fail. Changes: - Replace generic exception catching with specific error propagation - When raise_error=True, directly re-raise original exceptions from check_layout() and check_is_flash_attention() functions - Preserve detailed error context from JAX internal validation functions - Maintain existing behavior when raise_error=False (returns False) This improves debugging experience by surfacing specific technical details about tensor layout incompatibilities, cuDNN version requirements, and other flash attention compatibility issues. Relates to keras-hub PR #2257 and addresses flash attention debugging needs. --- keras/src/backend/jax/nn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py index cb2a7716c6ce..8eb06c301c5f 100644 --- a/keras/src/backend/jax/nn.py +++ b/keras/src/backend/jax/nn.py @@ -1072,9 +1072,9 @@ def _can_use_flash_attention(query, key, value, bias, raise_error=False): is_training=False, ) return True - except: + except Exception as e: if raise_error: - raise + raise e return False From f7a22907a4f47acd9619f7d9ab2aaf893e68354a Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Thu, 29 May 2025 15:39:01 +0530 Subject: [PATCH 16/20] Revert "Improve error handling in _can_use_flash_attention for better debugging" This reverts commit 7a0c5473c3091a2c90db031515c1c3f8daae8e7a. --- keras/src/backend/jax/nn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py index 8eb06c301c5f..cb2a7716c6ce 100644 --- a/keras/src/backend/jax/nn.py +++ b/keras/src/backend/jax/nn.py @@ -1072,9 +1072,9 @@ def _can_use_flash_attention(query, key, value, bias, raise_error=False): is_training=False, ) return True - except Exception as e: + except: if raise_error: - raise e + raise return False From 8bae8924329d9b61238521a3ae1f352157779232 Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Thu, 29 May 2025 15:47:22 +0530 Subject: [PATCH 17/20] Fix JAX API compatibility and improve error handling in `_can_use_flash_attention` Changes: - Add missing q_offsets=None and kv_offsets=None parameters to check_layout() call to match updated JAX function signature - Replace bare `except:` with `except Exception as e:` and `raise e` to preserve detailed error messages from JAX validation functions - Maintain existing fallback behavior when raise_error=False This resolves compatibility issues with newer JAX versions and improves debugging experience by surfacing specific technical details about flash attention compatibility failures. --- keras/src/backend/jax/nn.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py index cb2a7716c6ce..3dbd06e6a292 100644 --- a/keras/src/backend/jax/nn.py +++ b/keras/src/backend/jax/nn.py @@ -1062,6 +1062,8 @@ def _can_use_flash_attention(query, key, value, bias, raise_error=False): q_seqlen=None, kv_seqlen=None, layout=_normalize_layout("BTNH"), + q_offsets=None, + kv_offsets=None, ) check_is_flash_attention( query, @@ -1072,9 +1074,9 @@ def _can_use_flash_attention(query, key, value, bias, raise_error=False): is_training=False, ) return True - except: + except Exception as e: if raise_error: - raise + raise e return False From ee196cd1051135364931294c995cc693ceb59b87 Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Thu, 29 May 2025 16:15:07 +0530 Subject: [PATCH 18/20] Updated `dot_product_attention` Simplified the check for `flasth_attention` by removing redundant checks that are already done in `_can_use_flash_attention`. --- keras/src/backend/jax/nn.py | 52 +++++++------------------------------ 1 file changed, 10 insertions(+), 42 deletions(-) diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py index 3dbd06e6a292..7f097d6e35e8 100644 --- a/keras/src/backend/jax/nn.py +++ b/keras/src/backend/jax/nn.py @@ -1217,11 +1217,17 @@ def dot_product_attention( platform = jax.devices()[0].platform is_tpu = platform == "tpu" - # Get sharding parameters from distribution context - head_shards = 1 - q_seq_shards = 1 + # Determine flash attention compatibility + if flash_attention is None: + flash_attention = _can_use_flash_attention(query, key, value, bias) + elif flash_attention is True: + # Use `raise_error=True` to provide more details if the inputs failed to + # use flash attention + _can_use_flash_attention(query, key, value, bias, raise_error=True) - if is_tpu: + # TPU-specific flash attention path + if is_tpu and flash_attention: + # Get sharding parameters from distribution context try: from keras.src.distribution.distribution_lib import ModelParallel from keras.src.distribution.distribution_lib import ( @@ -1242,44 +1248,6 @@ def dot_product_attention( # Use default values if detection fails head_shards = 1 q_seq_shards = 1 - - # Check if inputs use partial sharding (not fully replicated) - # Flash attention works well with fully replicated tensors on all platforms - # but may have issues with certain partial sharding patterns on non-TPU - # platforms - partially_sharded_inputs = any( - hasattr(t, "sharding") and not t.sharding.is_fully_replicated - for t in (query, key, value) - ) - - # Determine flash attention compatibility - if flash_attention is None: - # Auto-detect flash attention availability - if is_tpu: - # TPUs have specialized hardware for attention that works with any - # sharding pattern - flash_attention = True - else: - # For GPU/CPU with partially sharded inputs, we need - # multiple devices to efficiently handle the sharding - if partially_sharded_inputs and len(jax.devices()) <= 1: - flash_attention = False - else: - flash_attention = _can_use_flash_attention( - query, key, value, bias - ) - elif flash_attention is True and not is_tpu: - # If flash attention is explicitly requested, validate compatibility - # Skip validation for TPU as it has specialized hardware support - try: - _can_use_flash_attention(query, key, value, bias, raise_error=True) - except Exception: - # Only disable flash attention on non-TPU platforms - # if validation fails - flash_attention = False - - # TPU-specific flash attention path - if is_tpu and flash_attention: # Transpose to ('batch', 'heads', 'length', 'head_dim') query_tpu_layout = jnp.transpose(query, axes=(0, 2, 1, 3)) key_tpu_layout = jnp.transpose(key, axes=(0, 2, 1, 3)) From 40583c886a541f454a371142fbd3d82a66a0bdff Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Sat, 7 Jun 2025 18:55:37 +0530 Subject: [PATCH 19/20] Update nn.py --- keras/src/backend/jax/nn.py | 34 +++++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py index 7f097d6e35e8..1a652539ffed 100644 --- a/keras/src/backend/jax/nn.py +++ b/keras/src/backend/jax/nn.py @@ -1074,9 +1074,9 @@ def _can_use_flash_attention(query, key, value, bias, raise_error=False): is_training=False, ) return True - except Exception as e: + except: if raise_error: - raise e + raise return False @@ -1121,7 +1121,7 @@ def _dot_product_attention_core( return jnp.einsum("BNTS,BSNH->BTNH", probs, value) -def wrap_flash_attention( +def wrap_flash_attention( query, key, value, @@ -1131,6 +1131,34 @@ def wrap_flash_attention( head_shards=1, q_seq_shards=1, ): + """ Applies a wrapped flash attention mechanism using the Splash kernel. + This function prepares the appropriate attention mask (causal or custom), + constructs a multi-head mask, and applies the Splash multi-head attention + kernel to the provided query, key, and value tensors. It supports optional + sharding and soft capping of attention logits. + Args: + query: jax.Array. The query tensor of shape + (batch, num_heads, seq_len, head_dim). + key: jax.Array. The key tensor of shape + (batch, num_heads, seq_len, head_dim). + value: jax.Array. The value tensor of shape + (batch, num_heads, seq_len, head_dim). + decoder_segment_ids: Optional. Segment IDs for the decoder, used for + sharding or masking. + custom_mask: Optional[jax.Array]. A custom attention mask to apply. If + None, a causal mask is used. + attn_logits_soft_cap: Optional[float]. If provided, applies a soft cap + to the attention logits. + head_shards: int, default=1. Number of shards for the attention heads. + q_seq_shards: int, default=1. Number of shards for the query sequence + dimension. + Returns: + jax.Array: The result of applying the Splash multi-head attention + kernel to the inputs. + Raises: + AssertionError: If sharding along the sequence dimension is attempted + with decoder_segment_ids. + """ if decoder_segment_ids is not None: assert query.shape[2] == decoder_segment_ids.q.shape[1], ( "Sharding along sequence dimension not allowed" From 7c918badc6a371936f409495814e5668dc2ffe72 Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Sat, 7 Jun 2025 19:12:28 +0530 Subject: [PATCH 20/20] Update nn.py --- keras/src/backend/jax/nn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py index 1a652539ffed..dbd91122ab84 100644 --- a/keras/src/backend/jax/nn.py +++ b/keras/src/backend/jax/nn.py @@ -1121,7 +1121,7 @@ def _dot_product_attention_core( return jnp.einsum("BNTS,BSNH->BTNH", probs, value) -def wrap_flash_attention( +def wrap_flash_attention( query, key, value, @@ -1131,7 +1131,7 @@ def wrap_flash_attention( head_shards=1, q_seq_shards=1, ): - """ Applies a wrapped flash attention mechanism using the Splash kernel. + """Applies a wrapped flash attention mechanism using the Splash kernel. This function prepares the appropriate attention mask (causal or custom), constructs a multi-head mask, and applies the Splash multi-head attention kernel to the provided query, key, and value tensors. It supports optional