From fa378dcc3e6a3c539d66ecf1c605df73776b0af5 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 10 Feb 2026 12:18:12 -0800 Subject: [PATCH] fix --- .../steel/attn/kernels/steel_attention.h | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h index 4de11b0819..491830949f 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h @@ -15,16 +15,6 @@ constant bool has_mask [[function_constant(300)]]; constant bool do_causal [[function_constant(301)]]; constant bool has_sinks [[function_constant(302)]]; -template -struct TransformScale { - T scale; - METAL_FUNC TransformScale(T scale_) : scale(scale_) {} - - METAL_FUNC T apply(T x) const { - return scale * x; - } -}; - struct MaxOp { template METAL_FUNC static constexpr T apply(T x, T y) { @@ -173,7 +163,7 @@ template < VBlockLoader loader_v( V, params->V_strides[2], Vs, simd_group_id, simd_lane_id); - TransformScale ts(static_cast(params->scale * M_LOG2E_F)); + const AccumType scale = params->scale * M_LOG2E_F; // Prepare MMA tiles constexpr short kFragSize = 8; // MMAFrag size @@ -216,13 +206,12 @@ template < threadgroup_barrier(mem_flags::mem_threadgroup); - // Load Q blocks apply scale + // Load Q blocks if (!align_Q && int(tid.x) == (params->NQ_aligned)) { loader_q.load_safe(short2(BD, params->qL_rem)); } else { loader_q.load_unsafe(); } - loader_q.apply_inplace_op(ts); // Init row reduction variables constexpr short kRowsPT = decltype(Stile)::kRowsPerThread; @@ -281,6 +270,12 @@ template < tile_matmad(Stile, Qtile, Ktile, Stile); } + // Apply scale in float32 + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) { + Stile.elems()[ii] *= scale; + } + // Mask out length sequence if (!align_K && kb == (params->NK_aligned)) { using stile_t = decltype(Stile);