Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 8 additions & 13 deletions mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,6 @@ constant bool has_mask [[function_constant(300)]];
constant bool do_causal [[function_constant(301)]];
constant bool has_sinks [[function_constant(302)]];

// Element-wise functor that multiplies every value passed to apply() by a
// fixed factor captured at construction time.
template <typename T>
struct TransformScale {
  // Multiplicative factor used by apply().
  T scale;

  // Stores `factor` as the scale used by subsequent apply() calls.
  METAL_FUNC TransformScale(T factor) : scale(factor) {}

  // Returns the product `scale * x`, converted to T.
  METAL_FUNC T apply(T x) const {
    const T scaled = scale * x;
    return scaled;
  }
};

struct MaxOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
Expand Down Expand Up @@ -173,7 +163,7 @@ template <
VBlockLoader loader_v(
V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);

TransformScale<T> ts(static_cast<T>(params->scale * M_LOG2E_F));
const AccumType scale = params->scale * M_LOG2E_F;

// Prepare MMA tiles
constexpr short kFragSize = 8; // MMAFrag size
Expand Down Expand Up @@ -216,13 +206,12 @@ template <

threadgroup_barrier(mem_flags::mem_threadgroup);

// Load Q blocks apply scale
// Load Q blocks
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
loader_q.load_safe(short2(BD, params->qL_rem));
} else {
loader_q.load_unsafe();
}
loader_q.apply_inplace_op(ts);

// Init row reduction variables
constexpr short kRowsPT = decltype(Stile)::kRowsPerThread;
Expand Down Expand Up @@ -281,6 +270,12 @@ template <
tile_matmad(Stile, Qtile, Ktile, Stile);
}

// Apply scale in float32
STEEL_PRAGMA_UNROLL
for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) {
Stile.elems()[ii] *= scale;
}

// Mask out length sequence
if (!align_K && kb == (params->NK_aligned)) {
using stile_t = decltype(Stile);
Expand Down
Loading