Skip to content

Commit 0fd3126

Browse files
JJJYmmm and johnmai-dev authored
[MODEL] support qwen3.5 series w/o vision (#869)
* support text-only qwen3.5 series Co-authored-by: johnmai-dev <johnmai-dev@users.noreply.github.com> * add test * fix sanitize and add test * make it more readable * fix lint --------- Co-authored-by: johnmai-dev <johnmai-dev@users.noreply.github.com>
1 parent ca0d1c9 commit 0fd3126

File tree

3 files changed

+534
-0
lines changed

3 files changed

+534
-0
lines changed

mlx_lm/models/qwen3_5.py

Lines changed: 383 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,383 @@
1+
# Copyright © 2026 Apple Inc.
2+
3+
from dataclasses import dataclass, field
4+
from typing import Any, Dict, List, Optional, Union
5+
6+
import mlx.core as mx
7+
import mlx.nn as nn
8+
from mlx.utils import tree_flatten, tree_unflatten
9+
10+
from .base import (
11+
BaseModelArgs,
12+
create_attention_mask,
13+
create_ssm_mask,
14+
)
15+
from .cache import ArraysCache, KVCache
16+
from .gated_delta import gated_delta_update
17+
from .qwen3_next import Qwen3NextAttention as Attention
18+
from .qwen3_next import Qwen3NextMLP as MLP
19+
from .qwen3_next import Qwen3NextRMSNormGated as RMSNormGated
20+
from .qwen3_next import Qwen3NextSparseMoeBlock as SparseMoeBlock
21+
22+
23+
@dataclass
class TextModelArgs(BaseModelArgs):
    """Configuration for the Qwen3.5 text decoder.

    Covers both the dense and the MoE variants; the MoE fields default to
    0 and are only consulted when ``num_experts > 0``.
    """

    model_type: str = ""
    hidden_size: int = 4096
    intermediate_size: int = 14336
    num_hidden_layers: int = 32
    num_attention_heads: int = 32
    rms_norm_eps: float = 1e-6
    vocab_size: int = 151936
    num_key_value_heads: int = 8
    max_position_embeddings: int = 131072
    # Geometry of the linear-attention (gated delta net) layers.
    linear_num_value_heads: int = 64
    linear_num_key_heads: int = 16
    linear_key_head_dim: int = 192
    linear_value_head_dim: int = 128
    linear_conv_kernel_dim: int = 4
    tie_word_embeddings: bool = False
    attention_bias: bool = False
    head_dim: Optional[int] = None
    # Every `full_attention_interval`-th layer is full attention; the rest
    # use the linear mixer (see DecoderLayer).
    full_attention_interval: int = 4

    # MoE fields (optional, for Qwen3_5MoeForConditionalGeneration)
    num_experts: int = 0
    num_experts_per_tok: int = 0
    decoder_sparse_step: int = 1
    shared_expert_intermediate_size: int = 0
    moe_intermediate_size: int = 0
    norm_topk_prob: bool = True

    # Rope parameters
    rope_parameters: Optional[Dict[str, Union[float, str, bool, List[int]]]] = field(
        default_factory=lambda: {
            "type": "default",
            "mrope_section": [11, 11, 10],
            "rope_theta": 100000,
            "partial_rotary_factor": 0.25,
        }
    )

    # Derived from rope_parameters (set in __post_init__)
    partial_rotary_factor: float = 0.25
    rope_theta: float = 100000.0
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None

    def __post_init__(self):
        """Derive head_dim and flatten rope settings out of rope_parameters."""
        if self.head_dim is None:
            self.head_dim = self.hidden_size // self.num_attention_heads

        if self.rope_parameters:
            # Some configs spell the key "rope_type"; normalize it to "type"
            # (mutates the incoming dict in place).
            if (
                "type" not in self.rope_parameters
                and "rope_type" in self.rope_parameters
            ):
                self.rope_parameters["type"] = self.rope_parameters.pop("rope_type")

            self.partial_rotary_factor = self.rope_parameters.get(
                "partial_rotary_factor", 0.25
            )
            self.rope_theta = self.rope_parameters.get("rope_theta", 100000.0)
            # Downstream attention code reads rope settings via rope_scaling.
            self.rope_scaling = self.rope_parameters
83+
84+
85+
class GatedDeltaNet(nn.Module):
    """Linear-attention token mixer (gated delta rule) preceded by a short
    depthwise causal convolution over the projected q/k/v channels.

    Recurrent state lives in a two-slot cache:
      cache[0] — the last ``conv_kernel_size - 1`` conv input steps,
      cache[1] — the gated-delta recurrent state.
    """

    def __init__(self, config: TextModelArgs):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_v_heads = config.linear_num_value_heads
        self.num_k_heads = config.linear_num_key_heads
        self.head_k_dim = config.linear_key_head_dim
        self.head_v_dim = config.linear_value_head_dim
        self.key_dim = self.head_k_dim * self.num_k_heads
        self.value_dim = self.head_v_dim * self.num_v_heads
        if self.num_v_heads % self.num_k_heads != 0:
            raise ValueError(
                f"num_v_heads ({self.num_v_heads}) must be divisible by num_k_heads ({self.num_k_heads})"
            )

        self.conv_kernel_size = config.linear_conv_kernel_dim
        self.layer_norm_epsilon = config.rms_norm_eps

        # Depthwise (groups == channels) conv over concatenated [q, k, v].
        self.conv_dim = self.key_dim * 2 + self.value_dim
        self.conv1d = nn.Conv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            bias=False,
            kernel_size=self.conv_kernel_size,
            groups=self.conv_dim,
            padding=0,
        )

        self.in_proj_qkv = nn.Linear(
            self.hidden_size, self.key_dim * 2 + self.value_dim, bias=False
        )
        # z gates the output RMSNorm; a and b feed the per-head decay terms
        # of gated_delta_update.
        self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False)
        self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
        self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)

        self.dt_bias = mx.ones(self.num_v_heads)

        # Random init; replaced by checkpoint weights at load time.
        A = mx.random.uniform(low=0, high=16, shape=(self.num_v_heads,))
        self.A_log = mx.log(A)

        self.norm = RMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon)

        self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Mix tokens with the gated delta rule.

        inputs is (B, S, hidden_size); the return value has the same shape.
        """
        B, S, _ = inputs.shape

        qkv = self.in_proj_qkv(inputs)
        z = self.in_proj_z(inputs).reshape(B, S, self.num_v_heads, self.head_v_dim)
        b = self.in_proj_b(inputs)
        a = self.in_proj_a(inputs)

        # Restore the conv tail from the cache, or start from zeros.
        if cache is not None and cache[0] is not None:
            conv_state = cache[0]
        else:
            conv_state = mx.zeros(
                (B, self.conv_kernel_size - 1, self.conv_dim),
                dtype=inputs.dtype,
            )

        # Zero masked positions so padding doesn't leak into the conv.
        if mask is not None:
            qkv = mx.where(mask[..., None], qkv, 0)
        conv_input = mx.concatenate([conv_state, qkv], axis=1)
        if cache is not None:
            # Save the last (kernel - 1) steps for the next call.
            cache[0] = conv_input[:, -(self.conv_kernel_size - 1) :]
        conv_out = nn.silu(self.conv1d(conv_input))

        # Split conv output back into per-head q, k, v tensors.
        q, k, v = [
            t.reshape(B, S, h, d)
            for t, h, d in zip(
                mx.split(conv_out, [self.key_dim, 2 * self.key_dim], -1),
                [self.num_k_heads, self.num_k_heads, self.num_v_heads],
                [self.head_k_dim, self.head_k_dim, self.head_v_dim],
            )
        ]

        state = cache[1] if cache else None
        # QK RMS-norm with scaling folded in. NOTE(review): q is scaled by
        # inv_scale**2 and k by inv_scale (total d**-1.5) — presumably
        # matches the reference kernel's expectations; confirm upstream.
        inv_scale = k.shape[-1] ** -0.5
        q = (inv_scale**2) * mx.fast.rms_norm(q, None, 1e-6)
        k = inv_scale * mx.fast.rms_norm(k, None, 1e-6)

        out, state = gated_delta_update(
            q,
            k,
            v,
            a,
            b,
            self.A_log,
            self.dt_bias,
            state,
            mask,
            # The fused kernel is used at inference; fall back when training.
            use_kernel=not self.training,
        )

        if cache is not None:
            cache[1] = state

        out = self.norm(out, z)
        return self.out_proj(out.reshape(B, S, -1))
189+
190+
191+
class DecoderLayer(nn.Module):
    """One Qwen3.5 decoder layer: a token mixer (linear gated-delta or full
    attention) and an MLP/MoE block, each on a pre-norm residual branch."""

    def __init__(self, args: TextModelArgs, layer_idx: int):
        super().__init__()
        # Every `full_attention_interval`-th layer (1-based) gets full
        # attention; all other layers use the linear mixer.
        self.is_linear = (layer_idx + 1) % args.full_attention_interval != 0
        if self.is_linear:
            self.linear_attn = GatedDeltaNet(args)
        else:
            self.self_attn = Attention(args)

        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )

        # Sparse MoE when experts are configured, plain dense MLP otherwise.
        if args.num_experts > 0:
            self.mlp = SparseMoeBlock(args)
        else:
            self.mlp = MLP(args.hidden_size, args.intermediate_size)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        mixer = self.linear_attn if self.is_linear else self.self_attn
        h = x + mixer(self.input_layernorm(x), mask, cache)
        return h + self.mlp(self.post_attention_layernorm(h))
223+
224+
225+
class Qwen3_5TextModel(nn.Module):
    """Decoder backbone: embeddings, the layer stack, and the final norm."""

    def __init__(self, args: TextModelArgs):
        super().__init__()
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            DecoderLayer(args=args, layer_idx=i) for i in range(args.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        # Indices of a representative linear (SSM) layer and a representative
        # full-attention layer, used to build the two mask types below.
        self.ssm_idx = 0
        self.fa_idx = args.full_attention_interval - 1

    def __call__(
        self,
        inputs: mx.array,
        cache: Optional[Any] = None,
        input_embeddings: Optional[mx.array] = None,
    ) -> mx.array:
        # Precomputed embeddings take precedence over token ids.
        h = self.embed_tokens(inputs) if input_embeddings is None else input_embeddings

        if cache is None:
            cache = [None] * len(self.layers)

        # One mask per mixer type, shared across all layers of that type.
        fa_mask = create_attention_mask(h, cache[self.fa_idx])
        ssm_mask = create_ssm_mask(h, cache[self.ssm_idx])

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask=ssm_mask if layer.is_linear else fa_mask, cache=c)

        return self.norm(h)
258+
259+
260+
class TextModel(nn.Module):
    """Text-only Qwen3.5 language model: decoder backbone plus LM head."""

    def __init__(self, args: TextModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = Qwen3_5TextModel(args)
        # With tied embeddings the input embedding matrix doubles as the
        # output projection, so no separate lm_head is created.
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache: Optional[Any] = None,
        input_embeddings: Optional[mx.array] = None,
    ) -> mx.array:
        """Return vocabulary logits for each position."""
        out = self.model(inputs, cache, input_embeddings=input_embeddings)
        if self.args.tie_word_embeddings:
            out = self.model.embed_tokens.as_linear(out)
        else:
            out = self.lm_head(out)
        return out

    @property
    def layers(self):
        # Exposed for generic per-layer utilities (caching, quantization).
        return self.model.layers

    def make_cache(self):
        """One cache per layer: a two-slot array cache (conv tail + delta
        state) for linear layers, a KV cache for full-attention layers."""
        return [ArraysCache(size=2) if l.is_linear else KVCache() for l in self.layers]

    def sanitize(self, weights):
        """Normalize checkpoint weights to this implementation's layout.

        Drops "mtp." (multi-token-prediction) weights, moves conv1d kernels
        to MLX's axis order, and — for unsanitized checkpoints only —
        shifts RMSNorm weights by +1 (the source checkpoint presumably
        stores zero-centered norm weights; TODO confirm against the
        reference implementation).
        """
        has_mtp_weights = any("mtp." in k for k in weights)
        # A kernel-sized last axis marks a weight still in the source
        # layout; already-converted kernels end with the singleton axis.
        has_unsanitized_conv1d = any(
            "conv1d.weight" in k and v.shape[-1] != 1 for k, v in weights.items()
        )
        # MTP weights or raw conv kernels both indicate an unsanitized
        # checkpoint, which also needs the norm-weight shift below.
        should_shift_norm_weights = has_mtp_weights or has_unsanitized_conv1d
        weights = {k: v for k, v in weights.items() if "mtp." not in k}

        if self.args.tie_word_embeddings:
            weights.pop("lm_head.weight", None)

        norm_keys = (
            ".input_layernorm.weight",
            ".post_attention_layernorm.weight",
            "model.norm.weight",
            ".q_norm.weight",
            ".k_norm.weight",
        )
        for k, v in weights.items():
            if "conv1d.weight" in k and v.shape[-1] != 1:
                weights[k] = v.moveaxis(2, 1)
            if should_shift_norm_weights and any(k.endswith(sfx) for sfx in norm_keys):
                # Only 1-D (vector) norm weights are shifted.
                if v.ndim == 1:
                    weights[k] = v + 1.0
        return weights

    @property
    def quant_predicate(self):
        """Quantization policy: MoE router gates get 8 bits / group size 64;
        all other weights use the caller's defaults. None for dense models."""
        if self.args.num_experts <= 0:
            return None

        def predicate(path, _):
            if path.endswith("mlp.gate") or path.endswith("shared_expert_gate"):
                return {"group_size": 64, "bits": 8}
            return True

        return predicate
326+
327+
328+
@dataclass
class ModelArgs(BaseModelArgs):
    """Top-level config wrapper; `text_config` holds the decoder settings."""

    model_type: str
    text_config: dict

    @classmethod
    def from_dict(cls, params):
        # Nested layout (multimodal-style config): defer to the base parser.
        if "text_config" in params:
            return super().from_dict(params)
        # Flat text-only config: the whole dict is the text config.
        return cls(model_type=params["model_type"], text_config=params)
338+
339+
340+
class Model(nn.Module):
    """Entry point matching the (possibly multimodal) checkpoint layout,
    exposing only the text decoder; vision weights are discarded."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.language_model = TextModel(TextModelArgs.from_dict(args.text_config))

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
        input_embeddings: Optional[mx.array] = None,
    ):
        return self.language_model(
            inputs, cache=cache, input_embeddings=input_embeddings
        )

    def sanitize(self, weights):
        """Map checkpoint keys into the `language_model.*` namespace, then
        delegate per-weight fixes to the text model's sanitize."""
        # Round-trip through the tree utilities to normalize key structure.
        weights = tree_unflatten(list(weights.items()))
        weights = dict(tree_flatten(weights))

        sanitized = {}
        for key, value in weights.items():
            # Text-only model: drop the vision tower entirely.
            if key.startswith("model.visual"):
                continue
            if key.startswith("model.language_model"):
                # Multimodal layout -> local layout.
                key = key.replace("model.language_model", "language_model.model")
            elif key.startswith("language_model."):
                # Already in the local namespace; keep as-is.
                pass
            else:
                # Bare text-only checkpoint keys get prefixed.
                key = "language_model." + key
            sanitized[key] = value
        return self.language_model.sanitize(sanitized)

    @property
    def layers(self):
        return self.language_model.model.layers

    def make_cache(self):
        return self.language_model.make_cache()

    @property
    def quant_predicate(self):
        return self.language_model.quant_predicate

0 commit comments

Comments
 (0)