diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py
index a018c2cc7a8..3d6cad8f87d 100644
--- a/keras/src/backend/jax/nn.py
+++ b/keras/src/backend/jax/nn.py
@@ -1062,6 +1062,8 @@ def _can_use_flash_attention(query, key, value, bias, raise_error=False):
             q_seqlen=None,
             kv_seqlen=None,
             layout=_normalize_layout("BTNH"),
+            q_offsets=None,
+            kv_offsets=None,
         )
         check_is_flash_attention(
             query,
@@ -1126,16 +1128,45 @@ def wrap_flash_attention(
     decoder_segment_ids,
     custom_mask=None,
     attn_logits_soft_cap=None,
+    head_shards=1,
+    q_seq_shards=1,
 ):
+    """Applies a wrapped flash attention mechanism using the Splash kernel.
+    This function prepares the appropriate attention mask (causal or custom),
+    constructs a multi-head mask, and applies the Splash multi-head attention
+    kernel to the provided query, key, and value tensors. It supports optional
+    sharding and soft capping of the attention logits.
+    Args:
+        query: jax.Array. The query tensor of shape
+            (batch, num_heads, seq_len, head_dim).
+        key: jax.Array. The key tensor of shape
+            (batch, num_heads, seq_len, head_dim).
+        value: jax.Array. The value tensor of shape
+            (batch, num_heads, seq_len, head_dim).
+        decoder_segment_ids: Optional. Segment IDs for the decoder, used for
+            sharding or masking.
+        custom_mask: Optional[jax.Array]. A custom attention mask to apply. If
+            None, a causal mask is used.
+        attn_logits_soft_cap: Optional[float]. If provided, applies a soft cap
+            to the attention logits.
+        head_shards: int, default=1. Number of shards for the attention heads.
+        q_seq_shards: int, default=1. Number of shards for the query sequence
+            dimension.
+    Returns:
+        jax.Array: The result of applying the Splash multi-head attention
+            kernel to the inputs.
+    Raises:
+        AssertionError: If sharding along the sequence dimension is attempted
+            with decoder_segment_ids.
+    """
     if decoder_segment_ids is not None:
         assert query.shape[2] == decoder_segment_ids.q.shape[1], (
-            "Sharding along sequence dimension not allowed in tpu kernel "
-            "attention"
+            "Sharding along sequence dimension not allowed"
+            " in TPU kernel attention"
         )
 
     if custom_mask is not None:
         mask = splash_attention_mask.NumpyMask(array=custom_mask)
-
     else:
         mask = splash_attention_mask.CausalMask(
             shape=(query.shape[2], query.shape[2])
@@ -1147,8 +1178,8 @@ def wrap_flash_attention(
     )
     splash_kernel = splash_attention_kernel.make_splash_mha(
         mask=multi_head_mask,
-        head_shards=1,
-        q_seq_shards=1,
+        head_shards=head_shards,
+        q_seq_shards=q_seq_shards,
         attn_logits_soft_cap=attn_logits_soft_cap,
     )
 
@@ -1168,6 +1199,38 @@ def dot_product_attention(
     flash_attention=None,
     attn_logits_soft_cap=None,
 ):
+    """Computes dot-product attention given query, key, and value.
+
+    This is the core attention computation used in transformers.
+    For TPU platforms, flash attention optimizations are automatically applied
+    when possible, and sharding parameters are inferred from the layout map
+    in the current distribution context.
+
+    Args:
+        query: Queries with shape `[batch, time, heads,
+            depth_k]`.
+        key: Keys with shape `[batch, time, heads,
+            depth_k]`.
+        value: Values with shape `[batch, time, heads,
+            depth_v]`.
+        bias: Optional bias with shape broadcastable to
+            `[batch, heads, dest_time, source_time]`.
+        mask: Optional mask with shape broadcastable to
+            `[batch, heads, dest_time, source_time]`.
+        scale: Float. Optional scale that is applied to the attention
+            computation.
+        is_causal: Boolean. Whether causal masking is applied.
+        flash_attention: Boolean. Whether to use flash attention optimization
+            for improved performance. Defaults to None, which means it will
+            be auto-determined based on the platform, input shapes, and
+            compatibility.
+        attn_logits_soft_cap: Float. Optional float to softly cap attention
+            logits and avoid numerical instability. Applied as:
+            `logits = logits / (1.0 + abs(logits) / attn_logits_soft_cap)`.
+
+    Returns:
+        JAX Array of shape `[batch, time, heads, depth_v]`.
+    """
     query = convert_to_tensor(query)
     key = convert_to_tensor(key)
     value = convert_to_tensor(value)
@@ -1177,6 +1240,12 @@ def dot_product_attention(
             f"Received: query.shape={query.shape}, key.shape={key.shape}, "
             f"value.shape={value.shape}."
         )
+
+    # Check platform
+    platform = jax.devices()[0].platform
+    is_tpu = platform == "tpu"
+
+    # Determine flash attention compatibility
     if flash_attention is None:
         flash_attention = _can_use_flash_attention(query, key, value, bias)
     elif flash_attention is True:
@@ -1184,40 +1253,110 @@ def dot_product_attention(
         # use flash attention
        _can_use_flash_attention(query, key, value, bias, raise_error=True)
 
-    if jax.devices()[0].platform == "tpu":
-        # Transpose to ('batch', 'heads', 'length', 'kv')
-        query = jnp.transpose(query, axes=(0, 2, 1, 3))
-        key = jnp.transpose(key, axes=(0, 2, 1, 3))
-        value = jnp.transpose(value, axes=(0, 2, 1, 3))
-        B, H, S, KV = query.shape
-
-        segment_ids = jnp.ones([B, S])
-        # {token_ids, padding_mask, segment_ids} enable packing
-        out = wrap_flash_attention(
-            query,
-            key,
-            value,
-            decoder_segment_ids=splash_attention_kernel.SegmentIds(
-                segment_ids, segment_ids
-            ),
-            custom_mask=mask,
-            attn_logits_soft_cap=attn_logits_soft_cap,
+    # TPU-specific flash attention path
+    if is_tpu and flash_attention:
+        # Get sharding parameters from the distribution context
+        try:
+            from keras.src.distribution.distribution_lib import ModelParallel
+            from keras.src.distribution.distribution_lib import (
+                distribution as get_dist,
+            )
+
+            # Get the current distribution if available
+            dist = get_dist()
+            if dist and isinstance(dist, ModelParallel):
+                mesh = dist.device_mesh
+                if "model" in mesh.axis_names:
+                    model_dim_index = mesh.axis_names.index("model")
+                    # Set head_shards based on the model dimension of the mesh
+                    head_shards = mesh.shape[model_dim_index]
+                    # Typically keep q_seq_shards=1 for best performance
+                    q_seq_shards = 1
+        except (ImportError, ValueError, AttributeError):
+            # Use default values if detection fails
+            head_shards = 1
+            q_seq_shards = 1
+        # Transpose to ('batch', 'heads', 'length', 'head_dim')
+        query_tpu_layout = jnp.transpose(query, axes=(0, 2, 1, 3))
+        key_tpu_layout = jnp.transpose(key, axes=(0, 2, 1, 3))
+        value_tpu_layout = jnp.transpose(value, axes=(0, 2, 1, 3))
+
+        bs, num_heads, q_len, head_dim = query_tpu_layout.shape
+
+        # Apply scale to query if provided
+        if scale is not None:
+            # The TPU kernel applies 1/sqrt(head_dim) internally; to get an
+            # overall QK^T * scale, scale the query by scale * sqrt(head_dim)
+            query_tpu_layout = query_tpu_layout * (scale * math.sqrt(head_dim))
+
+        # Create segment IDs for Splash Attention (for packing/batching)
+        segment_ids = jnp.zeros([bs, q_len], dtype=jnp.int32)
+        decoder_segment_ids = splash_attention_kernel.SegmentIds(
+            q=segment_ids, kv=segment_ids
         )
-        out = jnp.transpose(out, axes=(0, 2, 1, 3))
-        return out
 
-    # `dot_product_attention` is only available in jax>=0.4.31
+        # Process mask for Splash Attention
+        custom_mask = None
+        if mask is not None:
+            mask_bool = mask.astype("bool") if mask.dtype != jnp.bool_ else mask
+
+            if mask_bool.ndim == 3 and mask_bool.shape[0] == bs:
+                custom_mask = mask_bool[0]
+            elif mask_bool.ndim == 4 and mask_bool.shape[0] == bs:
+                custom_mask = mask_bool[0, 0]
+
+            if is_causal and custom_mask is not None:
+                causal_mask = jnp.tril(
+                    jnp.ones((q_len, q_len), dtype=jnp.bool_)
+                )
+                custom_mask = jnp.logical_and(custom_mask, causal_mask)
+
+        if custom_mask is None and is_causal:
+            custom_mask = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_))
+
+        try:
+            output = wrap_flash_attention(
+                query_tpu_layout,
+                key_tpu_layout,
+                value_tpu_layout,
+                decoder_segment_ids=decoder_segment_ids,
+                custom_mask=custom_mask,
+                attn_logits_soft_cap=attn_logits_soft_cap,
+                head_shards=head_shards,
+                q_seq_shards=q_seq_shards,
+            )
+            # Transpose output back to Keras layout
+            return jnp.transpose(output, axes=(0, 2, 1, 3))
+        except Exception:
+            flash_attention = False
+
+    # JAX native dot_product_attention for GPU, or fallback for TPU
     if hasattr(jax.nn, "dot_product_attention"):
-        return jax.nn.dot_product_attention(
-            query,
-            key,
-            value,
-            bias=bias,
-            mask=mask,
-            scale=scale,
-            is_causal=is_causal,
-            implementation="cudnn" if flash_attention else "xla",
-        )
+        try:
+            return jax.nn.dot_product_attention(
+                query,
+                key,
+                value,
+                bias=bias,
+                mask=mask,
+                scale=scale,
+                is_causal=is_causal,
+                implementation="cudnn" if flash_attention else "xla",
+            )
+        except Exception:
+            # If flash attention fails, fall back to the XLA implementation
            if flash_attention:
+                return jax.nn.dot_product_attention(
+                    query,
+                    key,
+                    value,
+                    bias=bias,
+                    mask=mask,
+                    scale=scale,
+                    is_causal=is_causal,
+                    implementation="xla",
+                )
+            raise
 
     if flash_attention:
         raise RuntimeError(
@@ -1228,6 +1367,9 @@ def dot_product_attention(
     # Ref: jax.nn.dot_product_attention
     # https://github.com/jax-ml/jax/blob/jax-v0.4.33/jax/_src/nn/functions.py#L886
     # Not support `query_seq_lengths` and `key_value_seq_lengths` args
+
+    # Fallback to custom XLA implementation
+    # This is the reference implementation from jax.nn.dot_product_attention
     output_shape = query.shape
     _, _, K, H = key.shape
     scale = (1.0 / jnp.sqrt(H)) if scale is None else scale
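
Reviewer note, not part of the patch: below is a minimal usage sketch of the behavior this diff adds, assuming the JAX backend and the public keras.distribution API. The "model" mesh axis name matches the axis the new sharding inference inspects; the shapes and seed are arbitrary illustrative values, not taken from the patch.

import numpy as np

import keras
from keras.src.backend.jax import nn as jax_nn  # module modified above

# Build a 1 x N device mesh; the new TPU path reads the size of the
# "model" axis to set head_shards (q_seq_shards stays 1).
devices = keras.distribution.list_devices()
mesh = keras.distribution.DeviceMesh(
    shape=(1, len(devices)), axis_names=("batch", "model"), devices=devices
)
layout_map = keras.distribution.LayoutMap(mesh)
keras.distribution.set_distribution(
    keras.distribution.ModelParallel(layout_map=layout_map)
)

batch, seq_len, num_heads, head_dim = 2, 128, 8, 64
rng = np.random.default_rng(0)
q = rng.standard_normal((batch, seq_len, num_heads, head_dim), dtype="float32")
k = rng.standard_normal((batch, seq_len, num_heads, head_dim), dtype="float32")
v = rng.standard_normal((batch, seq_len, num_heads, head_dim), dtype="float32")

# On TPU with flash attention available this takes the Splash-kernel path;
# elsewhere it falls back to jax.nn.dot_product_attention or the XLA
# reference implementation at the end of the function.
out = jax_nn.dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # (2, 128, 8, 64)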