135 changes: 94 additions & 41 deletions mlx_lm/models/longcat_flash.py
@@ -9,6 +9,7 @@
from .activations import swiglu
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .cache import CacheList, KVCache
from .mla import MultiLinear
from .rope_utils import initialize_rope
from .switch_layers import SwitchGLU
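The new import pulls in MultiLinear from the shared MLA helpers. Its implementation is not shown in this diff, so the sketch below is only an assumption inferred from how embed_q and unembed_out are used later: a stack of per-head linear maps whose weight has shape (num_heads, output_dims, input_dims) and which can be applied with or without the usual transpose.

# Illustrative sketch of the assumed MultiLinear behaviour -- the real class lives
# in mlx_lm/models/mla.py and may differ in detail.
import mlx.core as mx
import mlx.nn as nn

class MultiLinearSketch(nn.Module):
    def __init__(self, input_dims: int, output_dims: int, num_heads: int):
        super().__init__()
        scale = input_dims**-0.5
        # One (output_dims, input_dims) matrix per head, stacked along axis 0.
        self.weight = mx.random.uniform(
            -scale, scale, (num_heads, output_dims, input_dims)
        )

    def __call__(self, x, transpose: bool = True):
        # transpose=True:  y = x @ W.T per head (input_dims -> output_dims)
        # transpose=False: y = x @ W   per head (output_dims -> input_dims)
        w = self.weight.swapaxes(-1, -2) if transpose else self.weight
        return x @ w

Under that reading, embed_q would hold the key half of the old kv_b_proj (per head, transposed) and unembed_out the value half, which matches how sanitize repacks the checkpoint weights further down.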

@@ -80,10 +81,11 @@ def __init__(self, args: ModelArgs):
bias=args.attention_bias,
)
self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank)
self.kv_b_proj = nn.Linear(
self.kv_lora_rank,
self.num_attention_heads * (self.qk_nope_head_dim + args.v_head_dim),
bias=False,
self.embed_q = MultiLinear(
self.qk_nope_head_dim, self.kv_lora_rank, self.num_attention_heads
)
self.unembed_out = MultiLinear(
self.kv_lora_rank, self.v_head_dim, self.num_attention_heads
)

self.o_proj = nn.Linear(
@@ -122,56 +124,59 @@ def __call__(
B, L, _ = x.shape

if self.q_lora_rank is None:
q_states = self.q_proj(x)
q = self.q_proj(x)
else:
q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))

q_states = q_states.reshape(B, L, -1, self.qk_head_dim).transpose(0, 2, 1, 3)
q = q.reshape(B, L, self.num_attention_heads, self.qk_head_dim).transpose(
0, 2, 1, 3
)

if self.mla_scale_q_lora is not None:
q_states = q_states * self.mla_scale_q_lora
q = q * self.mla_scale_q_lora

q_pass, q_rot = mx.split(q_states, [self.qk_nope_head_dim], axis=-1)
q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)

compressed_kv = self.kv_a_proj_with_mqa(x)
k_pass, k_rot = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
k_pass = self.kv_a_layernorm(k_pass)
compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
kv_latent = self.kv_a_layernorm(compressed_kv)

if self.mla_scale_kv_lora is not None:
k_pass = k_pass * self.mla_scale_kv_lora
kv_latent = kv_latent * self.mla_scale_kv_lora

key_shape = (B, L, -1, self.qk_nope_head_dim + self.v_head_dim)
k_pass = self.kv_b_proj(k_pass).reshape(*key_shape).transpose(0, 2, 1, 3)
k_pass, value_states = mx.split(k_pass, [self.qk_nope_head_dim], axis=-1)
offset = cache.offset if cache is not None else 0
q_pe = self.rope(q_pe, offset)
k_pe = self.rope(k_pe, offset)

k_rot = k_rot.reshape(B, 1, L, self.qk_rope_head_dim)
kv_latent = mx.expand_dims(kv_latent, axis=1)

if cache is not None:
q_rot = self.rope(q_rot, cache.offset)
k_rot = self.rope(k_rot, cache.offset)
else:
q_rot = self.rope(q_rot)
k_rot = self.rope(k_rot)

k_rot = mx.broadcast_to(k_rot, (*k_pass.shape[:-1], k_rot.shape[-1]))
kv_latent, k_pe = cache.update_and_fetch(kv_latent, k_pe)

pe_scores = (q_pe * self.scale) @ k_pe.swapaxes(-1, -2)
if mask is not None:
pe_scores = mx.where(
mask,
pe_scores,
mx.array(mx.finfo(pe_scores.dtype).min, pe_scores.dtype),
)

query_states = mx.concatenate([q_pass, q_rot], axis=-1)
key_states = mx.concatenate([k_pass, k_rot], axis=-1)
if L == 1:
q_nope = self.embed_q(q_nope)
k = v = kv_latent
else:
k = self.embed_q(kv_latent, transpose=False)
v = self.unembed_out(kv_latent)

if cache is not None:
key_states, value_states = cache.update_and_fetch(key_states, value_states)

attn_output = scaled_dot_product_attention(
query_states,
key_states,
value_states,
cache=cache,
scale=self.scale,
mask=mask,
output = scaled_dot_product_attention(
q_nope, k, v, cache=cache, scale=self.scale, mask=pe_scores
)
if L == 1:
output = self.unembed_out(output)

attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(attn_output)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)


class LongcatFlashMLP(nn.Module):
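Taken together, the attention rewrite above swaps the materialize-then-cache scheme for DeepSeek-style latent caching with weight absorption. Previously the compressed KV latent was expanded through kv_b_proj into full per-head keys and values, and those were what the cache stored. Now the cache stores only the small latent (kv_latent) and the single shared RoPE key (k_pe); the rotary part of the score is precomputed as pe_scores and passed through the mask argument. During prefill (L > 1) the latent is still expanded per head with embed_q / unembed_out, much as before. During single-token decode (L == 1) the query is instead projected into the latent space with embed_q, attention runs directly against the cached latent (k = v = kv_latent), and unembed_out is applied to the attention output afterwards. Both paths compute the same thing because the per-head key projection can be absorbed into the query; a toy check of that identity (illustrative shapes only, not the model's):

# q . (Wk @ c) == (Wk.T @ q) . c for any latent c, which is why the decode path
# can score queries against the cached latent directly.
import mlx.core as mx

rank, nope_dim = 8, 4                      # stand-ins for kv_lora_rank, qk_nope_head_dim
Wk = mx.random.normal((nope_dim, rank))    # one head's key slice of the old kv_b_proj
q = mx.random.normal((nope_dim,))          # the non-rotary part of one query
c = mx.random.normal((rank,))              # one cached KV latent vector

score_materialized = q @ (Wk @ c)          # prefill-style: expand the key, then score
score_absorbed = (Wk.T @ q) @ c            # decode-style: absorb Wk into the query
assert mx.allclose(score_materialized, score_absorbed).item()

The same argument applies on the value side: attending over Wv @ c gives the same output as attending over c and applying Wv afterwards, since the attention weights do not depend on Wv.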
@@ -339,7 +344,7 @@ def __call__(
if cache is None:
cache = [(None, None)] * self.num_layers

mask = create_attention_mask(h, cache[0][0])
mask = create_attention_mask(h, cache[0][0], return_array=True)

for layer, c in zip(self.layers, cache):
h = layer(h, mask, cache=c)
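The mask is now requested as a materialized array because the attention layers apply it directly to the precomputed positional scores with mx.where; the "causal" string shortcut that scaled_dot_product_attention accepts would not work there. A rough sketch of the kind of boolean mask this assumes (a hypothetical stand-in, not the create_attention_mask implementation):

import mlx.core as mx

def boolean_causal_mask(L: int, offset: int) -> mx.array:
    # True where a query may attend: every cached position plus non-future
    # positions inside the current chunk of length L.
    q_pos = mx.arange(offset, offset + L)[:, None]
    k_pos = mx.arange(offset + L)[None, :]
    return k_pos <= q_pos                  # shape (L, offset + L)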
@@ -395,6 +400,47 @@ def sanitize(self, weights):
]
weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)

for l in range(self.args.num_layers):
for i in range(2):
prefix = f"model.layers.{l}.self_attn.{i}"
kv_b_key = f"{prefix}.kv_b_proj.weight"
if kv_b_key in weights:
num_heads = self.args.num_attention_heads
head_dim = self.args.qk_nope_head_dim + self.args.v_head_dim
quantized = f"{prefix}.kv_b_proj.scales" in weights
v = weights.pop(kv_b_key)

if quantized:
dims = self.args.kv_lora_rank
scales = weights.pop(f"{prefix}.kv_b_proj.scales")
biases = weights.pop(f"{prefix}.kv_b_proj.biases")
bits = (v.shape[-1] * 32) // dims
group_size = dims // scales.shape[-1]
v = mx.dequantize(
v, scales, biases, bits=bits, group_size=group_size
)

v = v.reshape(num_heads, head_dim, -1)
wk = mx.contiguous(
v[:, : self.args.qk_nope_head_dim, :].swapaxes(-1, -2)
)
wv = mx.contiguous(v[:, self.args.qk_nope_head_dim :, :])

if quantized:
wk, wk_s, wk_b = mx.quantize(
wk, bits=bits, group_size=group_size
)
wv, wv_s, wv_b = mx.quantize(
wv, bits=bits, group_size=group_size
)
weights[f"{prefix}.embed_q.scales"] = wk_s
weights[f"{prefix}.embed_q.biases"] = wk_b
weights[f"{prefix}.unembed_out.scales"] = wv_s
weights[f"{prefix}.unembed_out.biases"] = wv_b

weights[f"{prefix}.embed_q.weight"] = wk
weights[f"{prefix}.unembed_out.weight"] = wv

new_weights = {}
for k, v in weights.items():
if k.startswith("model.mtp"):
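The sanitize pass above splits each layer's kv_b_proj checkpoint weight into the stacked embed_q / unembed_out weights. For quantized checkpoints it first recovers the quantization settings from the packed shapes, dequantizes, slices, and re-quantizes with the same settings. The recovery works because quantized values are packed into 32-bit words: the packed last dimension times 32 // bits equals the unpacked dimension, and there is one scale per group of group_size values. A worked example with made-up numbers:

# Illustrative only -- the concrete sizes are not from any particular checkpoint.
kv_lora_rank = 512          # unpacked last dimension of kv_b_proj
packed_last_dim = 64        # v.shape[-1] for the packed quantized weight
scales_last_dim = 8         # scales.shape[-1]

bits = (packed_last_dim * 32) // kv_lora_rank    # 64 * 32 / 512 = 4 bits per value
group_size = kv_lora_rank // scales_last_dim     # 512 / 8 = 64 values per scale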
@@ -408,6 +454,7 @@ def make_cache(self):
def shard(self, group: Optional[mx.distributed.Group] = None):
group = group or mx.distributed.init()
N = group.size()
rank = group.rank()

for layer in self.model.layers:
for attn in layer.self_attn:
@@ -419,11 +466,17 @@ def shard(self, group: Optional[mx.distributed.Group] = None):
attn.q_b_proj = shard_linear(
attn.q_b_proj, "all-to-sharded", group=group
)
attn.kv_b_proj = shard_linear(
attn.kv_b_proj, "all-to-sharded", group=group
)
attn.o_proj = shard_linear(attn.o_proj, "sharded-to-all", group=group)
attn.num_attention_heads //= N
num_heads = attn.num_attention_heads
sh = rank * num_heads
eh = sh + num_heads

def shard_heads(w):
return w[sh:eh]

attn.embed_q.apply(shard_heads)
attn.unembed_out.apply(shard_heads)

for mlp in layer.mlps:
mlp.gate_proj = shard_linear(
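With the fused kv_b_proj gone, tensor-parallel sharding changes accordingly: there is no column-sharded K/V projection any more, so each rank simply keeps its own slice of attention heads by slicing the stacked embed_q / unembed_out weights along the head axis. A minimal sketch of that slicing on a stand-alone weight stack, with made-up sizes and rank:

import mlx.core as mx

num_heads, out_dim, in_dim = 8, 16, 4
weight = mx.random.normal((num_heads, out_dim, in_dim))   # stacked per-head weights

world_size, rank = 4, 1                                   # hypothetical 4-way shard
local_heads = num_heads // world_size
start = rank * local_heads
local_weight = weight[start : start + local_heads]        # this rank keeps heads 2-3
assert local_weight.shape == (local_heads, out_dim, in_dim)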
2 changes: 1 addition & 1 deletion mlx_lm/models/longcat_flash_ngram.py
@@ -161,7 +161,7 @@ def __call__(

h = self.ngram_embeddings(input_ids, cache=cache[0])

mask = create_attention_mask(h, cache[1][0])
mask = create_attention_mask(h, cache[1][0], return_array=True)

for layer, c in zip(self.layers, cache[1:]):
h = layer(h, mask, cache=c)