# Copyright © 2026 Apple Inc.
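
"""Qwen3.5 model for MLX.

Hybrid-attention decoder: GatedDeltaNet (gated delta rule) linear-attention
layers interleaved with full attention every ``full_attention_interval``
layers, with an optional sparse-MoE MLP, reusing building blocks from
``qwen3_next``.
"""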

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_unflatten

from .base import (
    BaseModelArgs,
    create_attention_mask,
    create_ssm_mask,
)
from .cache import ArraysCache, KVCache
from .gated_delta import gated_delta_update
from .qwen3_next import Qwen3NextAttention as Attention
from .qwen3_next import Qwen3NextMLP as MLP
from .qwen3_next import Qwen3NextRMSNormGated as RMSNormGated
from .qwen3_next import Qwen3NextSparseMoeBlock as SparseMoeBlock


@dataclass
class TextModelArgs(BaseModelArgs):
    model_type: str = ""
    hidden_size: int = 4096
    intermediate_size: int = 14336
    num_hidden_layers: int = 32
    num_attention_heads: int = 32
    rms_norm_eps: float = 1e-6
    vocab_size: int = 151936
    num_key_value_heads: int = 8
    max_position_embeddings: int = 131072
    linear_num_value_heads: int = 64
    linear_num_key_heads: int = 16
    linear_key_head_dim: int = 192
    linear_value_head_dim: int = 128
    linear_conv_kernel_dim: int = 4
    tie_word_embeddings: bool = False
    attention_bias: bool = False
    head_dim: Optional[int] = None
    full_attention_interval: int = 4

    # MoE fields (optional, for Qwen3_5MoeForConditionalGeneration)
    num_experts: int = 0
    num_experts_per_tok: int = 0
    decoder_sparse_step: int = 1
    shared_expert_intermediate_size: int = 0
    moe_intermediate_size: int = 0
    norm_topk_prob: bool = True

    # Rope parameters
    rope_parameters: Optional[Dict[str, Union[float, str, bool, List[int]]]] = field(
        default_factory=lambda: {
            "type": "default",
            "mrope_section": [11, 11, 10],
            "rope_theta": 100000,
            "partial_rotary_factor": 0.25,
        }
    )

    # Derived from rope_parameters (set in __post_init__)
    partial_rotary_factor: float = 0.25
    rope_theta: float = 100000.0
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None

    def __post_init__(self):
        if self.head_dim is None:
            self.head_dim = self.hidden_size // self.num_attention_heads

        if self.rope_parameters:
            if (
                "type" not in self.rope_parameters
                and "rope_type" in self.rope_parameters
            ):
                self.rope_parameters["type"] = self.rope_parameters.pop("rope_type")

            self.partial_rotary_factor = self.rope_parameters.get(
                "partial_rotary_factor", 0.25
            )
            self.rope_theta = self.rope_parameters.get("rope_theta", 100000.0)
            self.rope_scaling = self.rope_parameters


class GatedDeltaNet(nn.Module):
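    """Linear-attention layer implementing the gated delta rule.

    A short depthwise causal convolution runs over the concatenated q/k/v
    projections, followed by the recurrent ``gated_delta_update``.
    """
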
    def __init__(self, config: TextModelArgs):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_v_heads = config.linear_num_value_heads
        self.num_k_heads = config.linear_num_key_heads
        self.head_k_dim = config.linear_key_head_dim
        self.head_v_dim = config.linear_value_head_dim
        self.key_dim = self.head_k_dim * self.num_k_heads
        self.value_dim = self.head_v_dim * self.num_v_heads
        if self.num_v_heads % self.num_k_heads != 0:
            raise ValueError(
                f"num_v_heads ({self.num_v_heads}) must be divisible by num_k_heads ({self.num_k_heads})"
            )

        self.conv_kernel_size = config.linear_conv_kernel_dim
        self.layer_norm_epsilon = config.rms_norm_eps

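        # Depthwise causal conv over the concatenated q/k/v channels
        # (key_dim for q, key_dim for k, value_dim for v).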
        self.conv_dim = self.key_dim * 2 + self.value_dim
        self.conv1d = nn.Conv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            bias=False,
            kernel_size=self.conv_kernel_size,
            groups=self.conv_dim,
            padding=0,
        )

        self.in_proj_qkv = nn.Linear(
            self.hidden_size, self.key_dim * 2 + self.value_dim, bias=False
        )
        self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False)
        self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
        self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)

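        # Per-head decay/gate parameters for the delta rule; these initial
        # values are placeholders that checkpoint weights overwrite on load.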
        self.dt_bias = mx.ones(self.num_v_heads)

        A = mx.random.uniform(low=0, high=16, shape=(self.num_v_heads,))
        self.A_log = mx.log(A)

        self.norm = RMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon)

        self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, S, _ = inputs.shape

        qkv = self.in_proj_qkv(inputs)
        z = self.in_proj_z(inputs).reshape(B, S, self.num_v_heads, self.head_v_dim)
        b = self.in_proj_b(inputs)
        a = self.in_proj_a(inputs)

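        # cache[0] holds the rolling convolution state (the last
        # kernel_size - 1 inputs); cache[1] holds the recurrent state.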
        if cache is not None and cache[0] is not None:
            conv_state = cache[0]
        else:
            conv_state = mx.zeros(
                (B, self.conv_kernel_size - 1, self.conv_dim),
                dtype=inputs.dtype,
            )

        if mask is not None:
            qkv = mx.where(mask[..., None], qkv, 0)
        conv_input = mx.concatenate([conv_state, qkv], axis=1)
        if cache is not None:
            cache[0] = conv_input[:, -(self.conv_kernel_size - 1) :]
        conv_out = nn.silu(self.conv1d(conv_input))

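        # Split the conv output back into q, k (key_dim channels each) and v
        # (value_dim channels), reshaped to per-head layout.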
        q, k, v = [
            t.reshape(B, S, h, d)
            for t, h, d in zip(
                mx.split(conv_out, [self.key_dim, 2 * self.key_dim], -1),
                [self.num_k_heads, self.num_k_heads, self.num_v_heads],
                [self.head_k_dim, self.head_k_dim, self.head_v_dim],
            )
        ]

        state = cache[1] if cache else None
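        # RMS-normalize q and k; q is scaled by 1/d and k by 1/sqrt(d), so a
        # q @ k inner product carries an overall d**-1.5 factor.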
        inv_scale = k.shape[-1] ** -0.5
        q = (inv_scale**2) * mx.fast.rms_norm(q, None, 1e-6)
        k = inv_scale * mx.fast.rms_norm(k, None, 1e-6)

        out, state = gated_delta_update(
            q,
            k,
            v,
            a,
            b,
            self.A_log,
            self.dt_bias,
            state,
            mask,
            use_kernel=not self.training,
        )

        if cache is not None:
            cache[1] = state

        out = self.norm(out, z)
        return self.out_proj(out.reshape(B, S, -1))


class DecoderLayer(nn.Module):
    def __init__(self, args: TextModelArgs, layer_idx: int):
        super().__init__()
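        # Every full_attention_interval-th layer uses full (quadratic)
        # attention; all other layers use the linear GatedDeltaNet.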
        self.is_linear = (layer_idx + 1) % args.full_attention_interval != 0
        if self.is_linear:
            self.linear_attn = GatedDeltaNet(args)
        else:
            self.self_attn = Attention(args)

        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )

        if args.num_experts > 0:
            self.mlp = SparseMoeBlock(args)
        else:
            self.mlp = MLP(args.hidden_size, args.intermediate_size)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        if self.is_linear:
            r = self.linear_attn(self.input_layernorm(x), mask, cache)
        else:
            r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        out = h + self.mlp(self.post_attention_layernorm(h))
        return out


class Qwen3_5TextModel(nn.Module):
    def __init__(self, args: TextModelArgs):
        super().__init__()
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            DecoderLayer(args=args, layer_idx=i) for i in range(args.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
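        # Indices of the first linear and first full-attention layer; their
        # caches are used below to build the two mask variants.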
        self.ssm_idx = 0
        self.fa_idx = args.full_attention_interval - 1

    def __call__(
        self,
        inputs: mx.array,
        cache: Optional[Any] = None,
        input_embeddings: Optional[mx.array] = None,
    ) -> mx.array:
        if input_embeddings is not None:
            hidden_states = input_embeddings
        else:
            hidden_states = self.embed_tokens(inputs)

        if cache is None:
            cache = [None] * len(self.layers)

        fa_mask = create_attention_mask(hidden_states, cache[self.fa_idx])
        ssm_mask = create_ssm_mask(hidden_states, cache[self.ssm_idx])

        for layer, c in zip(self.layers, cache):
            mask = ssm_mask if layer.is_linear else fa_mask
            hidden_states = layer(hidden_states, mask=mask, cache=c)

        return self.norm(hidden_states)


class TextModel(nn.Module):
    def __init__(self, args: TextModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = Qwen3_5TextModel(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache: Optional[Any] = None,
        input_embeddings: Optional[mx.array] = None,
    ) -> mx.array:
        out = self.model(inputs, cache, input_embeddings=input_embeddings)
        if self.args.tie_word_embeddings:
            out = self.model.embed_tokens.as_linear(out)
        else:
            out = self.lm_head(out)
        return out

    @property
    def layers(self):
        return self.model.layers

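    # Linear layers need a two-slot cache (conv state + recurrent state);
    # full-attention layers use a standard KV cache.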
    def make_cache(self):
        return [ArraysCache(size=2) if l.is_linear else KVCache() for l in self.layers]

    def sanitize(self, weights):
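        # Heuristics for detecting a raw (unconverted) checkpoint: it carries
        # multi-token-prediction ("mtp.") weights and/or conv kernels still in
        # PyTorch layout. Such checkpoints also appear to store their RMSNorm
        # weights zero-centered, hence the +1.0 shift below.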
        has_mtp_weights = any("mtp." in k for k in weights)
        has_unsanitized_conv1d = any(
            "conv1d.weight" in k and v.shape[-1] != 1 for k, v in weights.items()
        )
        should_shift_norm_weights = has_mtp_weights or has_unsanitized_conv1d
        weights = {k: v for k, v in weights.items() if "mtp." not in k}

        if self.args.tie_word_embeddings:
            weights.pop("lm_head.weight", None)

        norm_keys = (
            ".input_layernorm.weight",
            ".post_attention_layernorm.weight",
            "model.norm.weight",
            ".q_norm.weight",
            ".k_norm.weight",
        )
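        # Move conv kernels from PyTorch layout (out, in/groups, kernel) to
        # MLX layout (out, kernel, in/groups), and shift zero-centered norm
        # weights by +1.0 to match nn.RMSNorm's multiplicative convention.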
        for k, v in weights.items():
            if "conv1d.weight" in k and v.shape[-1] != 1:
                weights[k] = v.moveaxis(2, 1)
            if should_shift_norm_weights and any(k.endswith(sfx) for sfx in norm_keys):
                if v.ndim == 1:
                    weights[k] = v + 1.0
        return weights

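    # Keep the small but routing-critical MoE gates at 8 bits / group size 64;
    # quantize everything else with the caller's defaults.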
    @property
    def quant_predicate(self):
        if self.args.num_experts <= 0:
            return None

        def predicate(path, _):
            if path.endswith("mlp.gate") or path.endswith("shared_expert_gate"):
                return {"group_size": 64, "bits": 8}
            return True

        return predicate


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    text_config: dict

    @classmethod
    def from_dict(cls, params):
        if "text_config" not in params:
            return cls(model_type=params["model_type"], text_config=params)
        return super().from_dict(params)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.language_model = TextModel(TextModelArgs.from_dict(args.text_config))

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
        input_embeddings: Optional[mx.array] = None,
    ):
        return self.language_model(
            inputs, cache=cache, input_embeddings=input_embeddings
        )

    def sanitize(self, weights):
        weights = tree_unflatten(list(weights.items()))
        weights = dict(tree_flatten(weights))

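        # Drop vision weights and remap Hugging Face-style prefixes
        # ("model.language_model.*") onto this module's layout
        # ("language_model.model.*") before the text-model sanitize runs.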
        sanitized = {}
        for key, value in weights.items():
            if key.startswith("model.visual"):
                continue
            if key.startswith("model.language_model"):
                key = key.replace("model.language_model", "language_model.model")
            elif key.startswith("language_model."):
                pass
            else:
                key = "language_model." + key
            sanitized[key] = value
        return self.language_model.sanitize(sanitized)

    @property
    def layers(self):
        return self.language_model.model.layers

    def make_cache(self):
        return self.language_model.make_cache()

    @property
    def quant_predicate(self):
        return self.language_model.quant_predicate
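

# Minimal smoke-test sketch (illustrative, not part of the library): run a
# single GatedDeltaNet layer with randomly initialized weights on tiny,
# assumed dimensions; fields not overridden fall back to the dataclass
# defaults above.
if __name__ == "__main__":
    cfg = TextModelArgs.from_dict(
        {
            "hidden_size": 64,
            "linear_num_value_heads": 4,
            "linear_num_key_heads": 2,
            "linear_key_head_dim": 16,
            "linear_value_head_dim": 16,
        }
    )
    layer = GatedDeltaNet(cfg)
    # No cache: the conv state starts at zeros and the recurrent state at None.
    y = layer(mx.zeros((1, 5, 64)))
    print(y.shape)  # (1, 5, 64)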