From ca50058921371e92741c6eb16045b88e7efc362e Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Tue, 6 May 2025 18:39:56 +0100 Subject: [PATCH 01/25] wip --- torchtune/models/deepseek_v3/__init__.py | 5 + torchtune/models/deepseek_v3/_attention.py | 104 ++++++++++++++++ .../models/deepseek_v3/_component_builders.py | 112 ++++++++++++++++++ torchtune/models/deepseek_v3/_linear.py | 33 ++++++ .../models/deepseek_v3/_model_builders.py | 0 torchtune/models/deepseek_v3/_moe.py | 67 +++++++++++ .../deepseek_v3/_position_embeddings.py | 5 + torchtune/modules/moe/moe.py | 1 - 8 files changed, 326 insertions(+), 1 deletion(-) create mode 100644 torchtune/models/deepseek_v3/__init__.py create mode 100644 torchtune/models/deepseek_v3/_attention.py create mode 100644 torchtune/models/deepseek_v3/_component_builders.py create mode 100644 torchtune/models/deepseek_v3/_linear.py create mode 100644 torchtune/models/deepseek_v3/_model_builders.py create mode 100644 torchtune/models/deepseek_v3/_moe.py create mode 100644 torchtune/models/deepseek_v3/_position_embeddings.py diff --git a/torchtune/models/deepseek_v3/__init__.py b/torchtune/models/deepseek_v3/__init__.py new file mode 100644 index 0000000000..2e41cd717f --- /dev/null +++ b/torchtune/models/deepseek_v3/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchtune/models/deepseek_v3/_attention.py b/torchtune/models/deepseek_v3/_attention.py new file mode 100644 index 0000000000..cf4db8f164 --- /dev/null +++ b/torchtune/models/deepseek_v3/_attention.py @@ -0,0 +1,104 @@ +import torch +import torch.nn as nn +from typing import Optional +from torchtune.modules.attention_utils import _MaskType +from torchtune.modules import RMSNorm +from torchtune.models.deepseek_v3 import DeepSeekV3LatentLinear + + +class DeepSeekV3Attention(nn.Module): + def __init__(self, + embed_dim: int, + num_heads: int, + qk_rope_head_dim: int, + v_head_dim: int, + qk_nope_head_dim: int, + q_head_dim: int, + q_proj: nn.Module, + kv_proj: DeepSeekV3LatentLinear, + output_proj: nn.Module, + kv_norm: nn.Module, + pos_embeddings: Optional[nn.Module] = None, + q_norm: Optional[nn.Module] = None, + # kv_cache: Optional[KVCache] = None, + max_seq_len: int = 4096, + is_causal: bool = True, + attn_dropout: float = 0.0,): + + self.num_heads = num_heads + self.embed_dim = embed_dim + self.attn_dropout = attn_dropout + self.q_head_dim = q_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.max_seq_len = max_seq_len + self.is_causal = is_causal + + # Set layers + # self.kv_cache = kv_cache + self.q_proj = q_proj + self.kv_proj = kv_proj + self.output_proj = output_proj + self.q_norm = q_norm + self.kv_norm = kv_norm + self.pos_embeddings = pos_embeddings + self.softmax_scale = self.q_head_dim ** (-0.5) + + def forward( + self, + x: torch.Tensor, + y: Optional[torch.Tensor] = None, + *, + mask: Optional[_MaskType] = None, + input_pos: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + # q is sometimes decomposed into A/B + # kv is *always* decomposed + + # when q is decomposed the norm is applied but + # not otherwise - in this case the norm + # should be applied after q a proj and before q b proj + + # for kv decomposition pos embeddings need to be extracted before + # projecting back up + + b, s_x, _ 
= x.shape + q = self.q_proj(x) + q = q.view(b, s_x, self.num_heads, self.q_head_dim) + q = q.transpose(1, 2) + + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + kv, k_pe = self.kv_proj(x) + kv = kv.view(b, s_x, self.num_kv_heads, self.qk_nope_head_dim + self.v_head_dim) + kv = kv.transpose(1, 2) + + k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + q_pe = self.pos_embeddings(q_pe, input_pos=input_pos) + k_pe = self.pos_embeddings(k_pe, input_pos=input_pos) + + query_states = q_pe.new_empty(b, self.num_heads, s_x, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(b, self.num_heads, s_x, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + + output = self._attention_call( + query_states, + key_states, + value_states, + mask=mask, + dropout_p=self.attn_dropout if self.training else 0.0, + is_causal=self.kv_cache is None and mask is None and self.is_causal, + ) + + # reshape the output to be the same shape as the input + output = output.transpose(1, 2).contiguous().view(b, s_x, -1) + return self.output_proj(output) diff --git a/torchtune/models/deepseek_v3/_component_builders.py b/torchtune/models/deepseek_v3/_component_builders.py new file mode 100644 index 0000000000..d462a6e257 --- /dev/null +++ b/torchtune/models/deepseek_v3/_component_builders.py @@ -0,0 +1,112 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch +from torch import nn +from torchtune.models.deepseek_v3._linear import DeepSeekV3LatentLinear +from torchtune.models.deepseek_v3._position_embeddings import DeepSeekV3RoPE +from torchtune.modules import ( + FeedForward, + MultiheadAttention, + RMSNorm, + TransformerDecoder, + TransformerDecoderLayer, + Tokenizer +) +from torchtune.modules.moe.experts import GroupedExperts +from torchtune.modules.position_embeddings import RotaryPositionalEmbeddings + +def deepseek_v3_mlp( + dim: int, + hidden_dim: int, + ffn_dropout: float = 0.0 +) -> FeedForward: + """ + Builds the FeedForward layer for DeepSeek V3. 
+ """ + gate_proj = nn.Linear(dim, hidden_dim, bias=False) + up_proj = nn.Linear(dim, hidden_dim, bias=False) + down_proj = nn.Linear(hidden_dim, dim, bias=False) + return FeedForward(gate_proj=gate_proj, up_proj=up_proj, down_proj=down_proj, dropout=ffn_dropout, activation_fn=nn.SiLU) + +def deepseek_v3( + *, + vocab_size: int, + embed_dim: int, + num_layers: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_seq_len: int, + rope_base: int = 10_000, + q_lora_rank: Optional[int] = None, + rope_head_dim: Optional[int] = None, + v_head_dim: Optional[int] = None, +): + rope = RotaryPositionalEmbeddings( + dim=head_dim, max_seq_len=max_seq_len, base=rope_base + ) + layers = [] + for i in range(num_layers): + if q_lora_rank is None: + q_proj = nn.Linear(embed_dim, num_heads * q_head_dim, bias=False) + else: + q_proj = DeepSeekV3LatentLinear( + in_dim=embed_dim, + out_dim=num_heads * q_head_dim, + rank=q_lora_rank, + ) + kv_proj = DeepSeekV3LatentLinear( + in_dim=embed_dim, + out_dim=num_kv_heads * (q_head_dim - rope_head_dim + v_head_dim), + rank=q_lora_rank, + ) + self_attn = MultiheadAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + pos_embeddings=rope, + q_proj=q_proj, + kv_proj=kv_proj, + o_proj=nn.Linear(num_heads * v_head_dim, embed_dim, bias=False), + ) + if i >= first_moe_layer and i % moe_every_n_layers == 0: + mlp_layer = MoE( + experts=GroupedExperts( + dim=embed_dim, hidden_dim=hidden_dim, num_experts=num_experts + ), + router=TokenChoiceTopKRouter( + gate=nn.Linear(embed_dim, num_experts, bias=False), + dim=embed_dim, + num_experts=num_experts, + experts_per_token=experts_per_token, + ), + shared_expert=( + deepseek_v3_mlp(dim=embed_dim, hidden_dim=hidden_dim) if use_shared_expert else None + ) + ) + else: + mlp_layer = deepseek_v3_mlp(dim=embed_dim, hidden_dim=hidden_dim) + + layer = TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp_layer, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + layers.append(layer) + layers = nn.ModuleList(layers) + + tok_embeddings = nn.Embedding(vocab_size, embed_dim) + output_proj = nn.Linear(embed_dim, vocab_size, bias=False) + return TransformerDecoder( + tok_embeddings=tok_embeddings, + layers=layers, + max_seq_len=max_seq_len, + ) \ No newline at end of file diff --git a/torchtune/models/deepseek_v3/_linear.py b/torchtune/models/deepseek_v3/_linear.py new file mode 100644 index 0000000000..2ce574db04 --- /dev/null +++ b/torchtune/models/deepseek_v3/_linear.py @@ -0,0 +1,33 @@ +import torch +import torch.nn as nn +from typing import Optional, Optional +from torchtune.modules import RMSNorm + +class DeepSeekV3LatentLinear(nn.Module): + def __init__( + self, + in_dim: int, + out_dim: int, + rank: int, + rope_head_dim: Optional[int] = None, + ): + super().__init__() + self.rope_head_dim = rope_head_dim + intermediate_dim = rope_head_dim + rank if rope_head_dim else rank + self.a_proj = nn.Linear( + in_features=in_dim, out_features=intermediate_dim, bias=False + ) + self.b_proj = nn.Linear(in_features=intermediate_dim, out_features=out_dim, bias=False) + self.norm = RMSNorm(rank) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, s_x, _ = x.shape + out = self.a_proj(x) + + if self.rope_head_dim: + out, rope_out = torch.split(out, [self.rank, self.rope_head_dim], dim=-1) + rope_out = rope_out.view(b, s_x, 1, self.rope_head_dim).transpose(1, 2) + out = self.b_proj(self.norm(out)) + return out, rope_out + + 
return out diff --git a/torchtune/models/deepseek_v3/_model_builders.py b/torchtune/models/deepseek_v3/_model_builders.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchtune/models/deepseek_v3/_moe.py b/torchtune/models/deepseek_v3/_moe.py new file mode 100644 index 0000000000..8c6dde745f --- /dev/null +++ b/torchtune/models/deepseek_v3/_moe.py @@ -0,0 +1,67 @@ + +import torch + +class DeepSeekV3MoE(nn.Module): + def __init__(self): + self.experts = experts + self.router = router + self.shared_expert = shared_expert + + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, s, h = x.shape + top_scores, token_idxs, num_tokens_per_expert = self.router(x.reshape(b * s, h)) + token_idxs = token_idxs.reshape(-1, 1).expand(-1, h) + routed_input = torch.gather(x.view(-1, h), dim=0, index=token_idxs) + routed_input = routed_input.reshape(b, s, h) + return self.experts(routed_input) + +class DeepSeekV3TokenChoiceTopKRouter(nn.Module): + def __init__(self): + self.gate = gate # nn.Linear + self.dim = dim + self.num_experts = num_experts + self.experts_per_token = experts_per_token + self.n_groups = n_groups + self.e_score_correction_bias = nn.Parameter(torch.empty((self.experts))) + self.topk_group = topk_group + self.norm_topk_prob = norm_topk_prob + self.routed_scaling_factor = routed_scaling_factor + + def forward(self, x: torch.Tensor) -> torch.Tensor: + n = x.shape[0] + scores = self.gate(x) + scores = torch.sigmoid(scores.to(torch.float32)).to(x.dtype) + + scores_for_choice = scores + self.e_score_correction_bias.unsqueeze(0) + group_scores = ( + scores_for_choice.view(n, self.n_groups, -1).topk(2, dim=-1)[0].sum(dim=-1) + ) + group_idx = torch.topk( + group_scores, k=self.topk_group, dim=-1, sorted=False + )[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand( + n, self.n_groups, self.n_routed_experts // self.n_groups + ) + .reshape(n, -1) + ) # [n, e] + tmp_scores = scores_for_choice.masked_fill( + ~score_mask.bool(), 0.0 + ) # [n, e] + _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False) + topk_weight = scores.gather(1, topk_idx) + + if self.num_experts > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + topk_weight = ( + topk_weight * self.routed_scaling_factor + ) # must multiply the scaling factor + + return topk_idx, topk_weight \ No newline at end of file diff --git a/torchtune/models/deepseek_v3/_position_embeddings.py b/torchtune/models/deepseek_v3/_position_embeddings.py new file mode 100644 index 0000000000..2e41cd717f --- /dev/null +++ b/torchtune/models/deepseek_v3/_position_embeddings.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
diff --git a/torchtune/modules/moe/moe.py b/torchtune/modules/moe/moe.py index b6fd008356..e188c0422f 100644 --- a/torchtune/modules/moe/moe.py +++ b/torchtune/modules/moe/moe.py @@ -61,7 +61,6 @@ def forward( scores, k=self.experts_per_token, dim=1 ) self.selected_experts_indices = selected_experts_indices - # top_scores /= top_scores.sum(dim=-1, keep_dim=True).to(x.dtype) # group tokens together by expert indices from 0 to num_experts and pass that to experts forward num_tokens_per_expert = torch.histc( From b623410ee9926bec044b869408b760b08da8d56e Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 8 May 2025 10:57:04 +0100 Subject: [PATCH 02/25] wip2 --- torchtune/models/deepseek_v3/_attention.py | 10 ++++++++-- torchtune/models/deepseek_v3/_linear.py | 8 +++++++- torchtune/models/deepseek_v3/_moe.py | 14 ++++++++++---- 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/torchtune/models/deepseek_v3/_attention.py b/torchtune/models/deepseek_v3/_attention.py index cf4db8f164..59a9bd61f1 100644 --- a/torchtune/models/deepseek_v3/_attention.py +++ b/torchtune/models/deepseek_v3/_attention.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + import torch import torch.nn as nn from typing import Optional @@ -53,7 +59,7 @@ def forward( mask: Optional[_MaskType] = None, input_pos: Optional[torch.Tensor] = None, ) -> torch.Tensor: - + # q is sometimes decomposed into A/B # kv is *always* decomposed @@ -61,7 +67,7 @@ def forward( # not otherwise - in this case the norm # should be applied after q a proj and before q b proj - # for kv decomposition pos embeddings need to be extracted before + # for kv decomposition pos embeddings need to be extracted before # projecting back up b, s_x, _ = x.shape diff --git a/torchtune/models/deepseek_v3/_linear.py b/torchtune/models/deepseek_v3/_linear.py index 2ce574db04..8ab3495134 100644 --- a/torchtune/models/deepseek_v3/_linear.py +++ b/torchtune/models/deepseek_v3/_linear.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + import torch import torch.nn as nn from typing import Optional, Optional @@ -19,7 +25,7 @@ def __init__( ) self.b_proj = nn.Linear(in_features=intermediate_dim, out_features=out_dim, bias=False) self.norm = RMSNorm(rank) - + def forward(self, x: torch.Tensor) -> torch.Tensor: b, s_x, _ = x.shape out = self.a_proj(x) diff --git a/torchtune/models/deepseek_v3/_moe.py b/torchtune/models/deepseek_v3/_moe.py index 8c6dde745f..b3e8171809 100644 --- a/torchtune/models/deepseek_v3/_moe.py +++ b/torchtune/models/deepseek_v3/_moe.py @@ -1,4 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ import torch class DeepSeekV3MoE(nn.Module): @@ -6,7 +12,7 @@ def __init__(self): self.experts = experts self.router = router self.shared_expert = shared_expert - + def forward(self, x: torch.Tensor) -> torch.Tensor: b, s, h = x.shape @@ -15,7 +21,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: routed_input = torch.gather(x.view(-1, h), dim=0, index=token_idxs) routed_input = routed_input.reshape(b, s, h) return self.experts(routed_input) - + class DeepSeekV3TokenChoiceTopKRouter(nn.Module): def __init__(self): self.gate = gate # nn.Linear @@ -32,7 +38,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: n = x.shape[0] scores = self.gate(x) scores = torch.sigmoid(scores.to(torch.float32)).to(x.dtype) - + scores_for_choice = scores + self.e_score_correction_bias.unsqueeze(0) group_scores = ( scores_for_choice.view(n, self.n_groups, -1).topk(2, dim=-1)[0].sum(dim=-1) @@ -64,4 +70,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: topk_weight * self.routed_scaling_factor ) # must multiply the scaling factor - return topk_idx, topk_weight \ No newline at end of file + return topk_idx, topk_weight From c835a06d7153aae31db2b6cd10b72c87d4da2655 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Sun, 11 May 2025 15:28:03 +0100 Subject: [PATCH 03/25] completed first draft --- torchtune/models/deepseek_v3/_attention.py | 1 - .../models/deepseek_v3/_component_builders.py | 107 ++++++++++------- torchtune/models/deepseek_v3/_moe.py | 110 +++++++++++------- torchtune/modules/moe/experts.py | 3 + torchtune/modules/moe/moe.py | 3 +- 5 files changed, 140 insertions(+), 84 deletions(-) diff --git a/torchtune/models/deepseek_v3/_attention.py b/torchtune/models/deepseek_v3/_attention.py index 59a9bd61f1..8af2195942 100644 --- a/torchtune/models/deepseek_v3/_attention.py +++ b/torchtune/models/deepseek_v3/_attention.py @@ -26,7 +26,6 @@ def __init__(self, kv_norm: nn.Module, pos_embeddings: Optional[nn.Module] = None, q_norm: Optional[nn.Module] = None, - # kv_cache: Optional[KVCache] = None, max_seq_len: int = 4096, is_causal: bool = True, attn_dropout: float = 0.0,): diff --git a/torchtune/models/deepseek_v3/_component_builders.py b/torchtune/models/deepseek_v3/_component_builders.py index d462a6e257..785e0d82ea 100644 --- a/torchtune/models/deepseek_v3/_component_builders.py +++ b/torchtune/models/deepseek_v3/_component_builders.py @@ -9,6 +9,8 @@ import torch from torch import nn from torchtune.models.deepseek_v3._linear import DeepSeekV3LatentLinear +from torchtune.models.deepseek_v3._attention import DeepSeekV3Attention +from torchtune.models.deepseek_v3._moe import DeepSeekV3TokenChoiceTopKRouter from torchtune.models.deepseek_v3._position_embeddings import DeepSeekV3RoPE from torchtune.modules import ( FeedForward, @@ -16,24 +18,14 @@ RMSNorm, TransformerDecoder, TransformerDecoderLayer, - Tokenizer + Tokenizer, + TransformerSelfAttentionLayer, ) from torchtune.modules.moe.experts import GroupedExperts +from torchtune.modules.moe.moe import MoE from torchtune.modules.position_embeddings import RotaryPositionalEmbeddings -def deepseek_v3_mlp( - dim: int, - hidden_dim: int, - ffn_dropout: float = 0.0 -) -> FeedForward: - """ - Builds the FeedForward layer for DeepSeek V3. 
- """ - gate_proj = nn.Linear(dim, hidden_dim, bias=False) - up_proj = nn.Linear(dim, hidden_dim, bias=False) - down_proj = nn.Linear(hidden_dim, dim, bias=False) - return FeedForward(gate_proj=gate_proj, up_proj=up_proj, down_proj=down_proj, dropout=ffn_dropout, activation_fn=nn.SiLU) - + def deepseek_v3( *, vocab_size: int, @@ -47,52 +39,67 @@ def deepseek_v3( q_lora_rank: Optional[int] = None, rope_head_dim: Optional[int] = None, v_head_dim: Optional[int] = None, + moe_every_n_layers: Optional[int] = None, + first_moe_layer: Optional[int] = None, + num_experts: Optional[int] = None, + num_groups: Optional[int] = None, + topk_groups: Optional[int] = None, + norm_topk_prob: Optional[float] = None, + routed_scaling_factor: Optional[float] = None, + experts_per_token: Optional[float] = None, + mlp_hidden_dim: Optional[int] = None, + norm_eps: float = 1e-5, ): + head_dim = embed_dim // num_heads rope = RotaryPositionalEmbeddings( dim=head_dim, max_seq_len=max_seq_len, base=rope_base ) layers = [] for i in range(num_layers): - if q_lora_rank is None: - q_proj = nn.Linear(embed_dim, num_heads * q_head_dim, bias=False) + if q_lora_rank is not None: + q_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=False) else: - q_proj = DeepSeekV3LatentLinear( - in_dim=embed_dim, - out_dim=num_heads * q_head_dim, - rank=q_lora_rank, - ) - kv_proj = DeepSeekV3LatentLinear( - in_dim=embed_dim, - out_dim=num_kv_heads * (q_head_dim - rope_head_dim + v_head_dim), - rank=q_lora_rank, - ) - self_attn = MultiheadAttention( + q_proj = DeepSeekV3LatentLinear(embed_dim, num_heads * head_dim, q_lora_rank) + + self_attn = DeepSeekV3Attention( embed_dim=embed_dim, num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - pos_embeddings=rope, + qk_rope_head_dim=head_dim, + v_head_dim=v_head_dim, + qk_nope_head_dim=head_dim, + q_head_dim=head_dim, q_proj=q_proj, - kv_proj=kv_proj, - o_proj=nn.Linear(num_heads * v_head_dim, embed_dim, bias=False), + kv_proj=DeepSeekV3LatentLinear(embed_dim, num_kv_heads * head_dim * 2, kv_lora_rank), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + kv_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + pos_embeddings=rope, + q_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + max_seq_len=max_seq_len, + is_causal=True, + attn_dropout=0.0, ) - if i >= first_moe_layer and i % moe_every_n_layers == 0: + is_moe = (moe_every_n_layers is None or (i + 1) % moe_every_n_layers == 0) and i >= first_moe_layer + if is_moe: mlp_layer = MoE( experts=GroupedExperts( - dim=embed_dim, hidden_dim=hidden_dim, num_experts=num_experts + + num_experts=num_experts, ), - router=TokenChoiceTopKRouter( + router=DeepSeekV3TokenChoiceTopKRouter( gate=nn.Linear(embed_dim, num_experts, bias=False), dim=embed_dim, num_experts=num_experts, experts_per_token=experts_per_token, + num_groups=num_groups, + topk_groups=topk_groups, + norm_topk_prob=norm_topk_prob, + routed_scaling_factor=routed_scaling_factor, ), - shared_expert=( - deepseek_v3_mlp(dim=embed_dim, hidden_dim=hidden_dim) if use_shared_expert else None - ) + shared_expert=deepseek_v3_mlp(embed_dim, mlp_hidden_dim), ) else: - mlp_layer = deepseek_v3_mlp(dim=embed_dim, hidden_dim=hidden_dim) + mlp_layer = deepseek_v3_mlp(embed_dim, mlp_hidden_dim) + layer = TransformerSelfAttentionLayer( attn=self_attn, @@ -101,6 +108,7 @@ def deepseek_v3( mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), ) layers.append(layer) + layers = nn.ModuleList(layers) tok_embeddings = nn.Embedding(vocab_size, embed_dim) @@ -109,4 +117,23 @@ def deepseek_v3( 
tok_embeddings=tok_embeddings, layers=layers, max_seq_len=max_seq_len, - ) \ No newline at end of file + num_heads=num_heads, + head_dim=head_dim, + norm=RMSNorm(dim=embed_dim, eps=norm_eps), + output=output_proj, + ) + + +def deepseek_v3_mlp( + dim: int, + hidden_dim: int +) -> FeedForward: + """ + Builds the FeedForward layer for DeepSeek V3. + """ + gate_proj = nn.Linear(dim, hidden_dim, bias=False) + up_proj = nn.Linear(dim, hidden_dim, bias=False) + down_proj = nn.Linear(hidden_dim, dim, bias=False) + return FeedForward(gate_proj=gate_proj, up_proj=up_proj, down_proj=down_proj) + + diff --git a/torchtune/models/deepseek_v3/_moe.py b/torchtune/models/deepseek_v3/_moe.py index b3e8171809..c12a3c2882 100644 --- a/torchtune/models/deepseek_v3/_moe.py +++ b/torchtune/models/deepseek_v3/_moe.py @@ -6,68 +6,94 @@ # LICENSE file in the root directory of this source tree. import torch +from torch import nn -class DeepSeekV3MoE(nn.Module): - def __init__(self): - self.experts = experts - self.router = router - self.shared_expert = shared_expert - - - def forward(self, x: torch.Tensor) -> torch.Tensor: - b, s, h = x.shape - top_scores, token_idxs, num_tokens_per_expert = self.router(x.reshape(b * s, h)) - token_idxs = token_idxs.reshape(-1, 1).expand(-1, h) - routed_input = torch.gather(x.view(-1, h), dim=0, index=token_idxs) - routed_input = routed_input.reshape(b, s, h) - return self.experts(routed_input) class DeepSeekV3TokenChoiceTopKRouter(nn.Module): - def __init__(self): - self.gate = gate # nn.Linear + def __init__(self, + gate: nn.Module, + dim: int, + num_experts: int, + experts_per_token: int, + num_groups: int, + topk_groups: int, + norm_topk_prob: bool, + routed_scaling_factor: float + ): + self.gate = gate self.dim = dim self.num_experts = num_experts self.experts_per_token = experts_per_token - self.n_groups = n_groups - self.e_score_correction_bias = nn.Parameter(torch.empty((self.experts))) - self.topk_group = topk_group + self.num_groups = num_groups + self.topk_groups = topk_groups self.norm_topk_prob = norm_topk_prob self.routed_scaling_factor = routed_scaling_factor + self.e_score_correction_bias = nn.Parameter(torch.rand((self.num_experts))) def forward(self, x: torch.Tensor) -> torch.Tensor: n = x.shape[0] - scores = self.gate(x) - scores = torch.sigmoid(scores.to(torch.float32)).to(x.dtype) + logits = self.gate(x) - scores_for_choice = scores + self.e_score_correction_bias.unsqueeze(0) + # calculate scores for every expert in every group + scores = torch.sigmoid(logits.to(torch.float32)).to(x.dtype) + scores += self.e_score_correction_bias.unsqueeze(0) + + # now calculate scores for every group based on the + # top 2 scores of experts within each group + experts_per_group = self.num_experts // self.num_groups group_scores = ( - scores_for_choice.view(n, self.n_groups, -1).topk(2, dim=-1)[0].sum(dim=-1) + scores.view(n, self.num_groups, experts_per_group) + .topk(2, dim=-1)[0].sum(dim=-1) ) - group_idx = torch.topk( - group_scores, k=self.topk_group, dim=-1, sorted=False - )[ + + # grab the topk_groups number of groups based + # on the scores for each group calculated above + group_idxs = torch.topk( + group_scores, k=self.topk_groups, dim=-1, sorted=False)[ 1 - ] # [n, top_k_group] - group_mask = torch.zeros_like(group_scores) # [n, n_group] - group_mask.scatter_(1, group_idx, 1) # [n, n_group] + ] + + # mask out all experts within groups which will not be considered + group_mask = torch.zeros_like(group_scores, dtype=torch.bool) + group_mask.scatter_(1, group_idxs, True) 
# [n, n_group] + score_mask = ( group_mask.unsqueeze(-1) .expand( - n, self.n_groups, self.n_routed_experts // self.n_groups + n, self.n_groups, experts_per_group ) .reshape(n, -1) - ) # [n, e] - tmp_scores = scores_for_choice.masked_fill( - ~score_mask.bool(), 0.0 - ) # [n, e] - _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False) - topk_weight = scores.gather(1, topk_idx) + ) + masked_scores = scores.masked_fill( + ~score_mask, float('-inf') + ) + + # now select the top experts_per_token number of + # experts based on experts within eligible groups + _, selected_experts_idxs = torch.topk(masked_scores, k=self.experts_per_token, dim=-1, sorted=False) + scores_per_token = scores.gather(1, selected_experts_idxs) + + # normalize scores if self.num_experts > 1 and self.norm_topk_prob: - denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 - topk_weight = topk_weight / denominator - topk_weight = ( - topk_weight * self.routed_scaling_factor - ) # must multiply the scaling factor + denominator = scores_per_token.sum(dim=-1, keepdim=True) + 1e-20 + scores_per_token /= denominator + + # apply scaling factor + scores_per_token = ( + scores_per_token * self.routed_scaling_factor + ) - return topk_idx, topk_weight + num_tokens_per_expert = torch.histc( + selected_experts_idxs.float(), bins=self.num_experts, min=0, max=self.num_experts - 1 + ).to(torch.int32) + + token_idxs_experts_sorted = torch.argsort( + selected_experts_idxs.view(-1), stable=True + ) + + scores_per_expert = scores_per_token.view(-1)[token_idxs_experts_sorted] + token_idxs_experts_sorted = ( + token_idxs_experts_sorted // self.experts_per_token + ) + return scores_per_expert, token_idxs_experts_sorted, num_tokens_per_expert diff --git a/torchtune/modules/moe/experts.py b/torchtune/modules/moe/experts.py index c667662e06..e4838d961b 100644 --- a/torchtune/modules/moe/experts.py +++ b/torchtune/modules/moe/experts.py @@ -65,6 +65,9 @@ def forward( ) out_experts_splits = [] for expert_idx, x_expert in enumerate(x): + if x_expert.numel() == 0: + out_experts_splits.append(torch.zeros_like(x_expert)) + continue w1, w2, w3 = ( self.gate_proj[expert_idx], self.down_proj[expert_idx], diff --git a/torchtune/modules/moe/moe.py b/torchtune/modules/moe/moe.py index e188c0422f..8e606be0f3 100644 --- a/torchtune/modules/moe/moe.py +++ b/torchtune/modules/moe/moe.py @@ -140,6 +140,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: out = self.shared_expert(x).reshape(bs * slen, dim) else: out = torch.zeros_like(x.reshape(bs * slen, dim)) - out = out.scatter_add(dim=0, index=token_indices, src=routed_output) + if routed_output.numel() > 0: + out = out.scatter_add(dim=0, index=token_indices, src=routed_output) out = out.reshape(bs, slen, dim) return out From 0e25699cb14fd74ba7cebad721c06edf677cec43 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Tue, 13 May 2025 12:39:34 +0100 Subject: [PATCH 04/25] adding weight conversion --- torchtune/models/deepseek_v3/_attention.py | 13 +- .../models/deepseek_v3/_component_builders.py | 60 +++--- .../models/deepseek_v3/_convert_weights.py | 193 ++++++++++++++++++ torchtune/models/deepseek_v3/_linear.py | 20 +- torchtune/models/deepseek_v3/_moe.py | 2 +- torchtune/models/llama4/_convert_weights.py | 2 +- torchtune/modules/moe/moe.py | 2 +- 7 files changed, 247 insertions(+), 45 deletions(-) create mode 100644 torchtune/models/deepseek_v3/_convert_weights.py diff --git a/torchtune/models/deepseek_v3/_attention.py b/torchtune/models/deepseek_v3/_attention.py index 
8af2195942..2b35f197f7 100644 --- a/torchtune/models/deepseek_v3/_attention.py +++ b/torchtune/models/deepseek_v3/_attention.py @@ -9,7 +9,6 @@ from typing import Optional from torchtune.modules.attention_utils import _MaskType from torchtune.modules import RMSNorm -from torchtune.models.deepseek_v3 import DeepSeekV3LatentLinear class DeepSeekV3Attention(nn.Module): @@ -21,15 +20,13 @@ def __init__(self, qk_nope_head_dim: int, q_head_dim: int, q_proj: nn.Module, - kv_proj: DeepSeekV3LatentLinear, + kv_proj: nn.Module, output_proj: nn.Module, - kv_norm: nn.Module, pos_embeddings: Optional[nn.Module] = None, - q_norm: Optional[nn.Module] = None, max_seq_len: int = 4096, is_causal: bool = True, attn_dropout: float = 0.0,): - + super().__init__() self.num_heads = num_heads self.embed_dim = embed_dim self.attn_dropout = attn_dropout @@ -45,11 +42,11 @@ def __init__(self, self.q_proj = q_proj self.kv_proj = kv_proj self.output_proj = output_proj - self.q_norm = q_norm - self.kv_norm = kv_norm self.pos_embeddings = pos_embeddings self.softmax_scale = self.q_head_dim ** (-0.5) - + self.cache_enabled = False + + def forward( self, x: torch.Tensor, diff --git a/torchtune/models/deepseek_v3/_component_builders.py b/torchtune/models/deepseek_v3/_component_builders.py index 785e0d82ea..c8fcc1c0c4 100644 --- a/torchtune/models/deepseek_v3/_component_builders.py +++ b/torchtune/models/deepseek_v3/_component_builders.py @@ -11,14 +11,10 @@ from torchtune.models.deepseek_v3._linear import DeepSeekV3LatentLinear from torchtune.models.deepseek_v3._attention import DeepSeekV3Attention from torchtune.models.deepseek_v3._moe import DeepSeekV3TokenChoiceTopKRouter -from torchtune.models.deepseek_v3._position_embeddings import DeepSeekV3RoPE from torchtune.modules import ( FeedForward, - MultiheadAttention, RMSNorm, TransformerDecoder, - TransformerDecoderLayer, - Tokenizer, TransformerSelfAttentionLayer, ) from torchtune.modules.moe.experts import GroupedExperts @@ -32,12 +28,12 @@ def deepseek_v3( embed_dim: int, num_layers: int, num_heads: int, - num_kv_heads: int, - head_dim: int, max_seq_len: int, rope_base: int = 10_000, q_lora_rank: Optional[int] = None, - rope_head_dim: Optional[int] = None, + qk_rope_head_dim: Optional[int] = None, + qk_nope_head_dim: Optional[int] = None, + kv_lora_rank: Optional[int] = None, v_head_dim: Optional[int] = None, moe_every_n_layers: Optional[int] = None, first_moe_layer: Optional[int] = None, @@ -48,19 +44,33 @@ def deepseek_v3( routed_scaling_factor: Optional[float] = None, experts_per_token: Optional[float] = None, mlp_hidden_dim: Optional[int] = None, + moe_hidden_dim: Optional[int] = None, norm_eps: float = 1e-5, ): head_dim = embed_dim // num_heads - rope = RotaryPositionalEmbeddings( - dim=head_dim, max_seq_len=max_seq_len, base=rope_base - ) + rope = nn.Identity() layers = [] for i in range(num_layers): - if q_lora_rank is not None: - q_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=False) - else: - q_proj = DeepSeekV3LatentLinear(embed_dim, num_heads * head_dim, q_lora_rank) + # q is sometimes decomposed into A/B (if q_lora_rank) + # kv is *always* decomposed + + # when q is decomposed the norm is applied but + # not otherwise - in this case the norm + # should be applied after q a proj and before q b proj + + # for kv decomposition pos embeddings need to be extracted before + # projecting back up + q_head_dim = qk_rope_head_dim + qk_nope_head_dim + if q_lora_rank is None: + q_proj = nn.Linear(embed_dim, num_heads * q_head_dim, bias=False) + else: + q_proj = 
DeepSeekV3LatentLinear( + in_dim=embed_dim, + out_dim=num_heads * q_head_dim, + rank=q_lora_rank, + norm=RMSNorm(dim=q_lora_rank), + ) self_attn = DeepSeekV3Attention( embed_dim=embed_dim, num_heads=num_heads, @@ -69,11 +79,13 @@ def deepseek_v3( qk_nope_head_dim=head_dim, q_head_dim=head_dim, q_proj=q_proj, - kv_proj=DeepSeekV3LatentLinear(embed_dim, num_kv_heads * head_dim * 2, kv_lora_rank), - output_proj=nn.Linear(embed_dim, embed_dim, bias=False), - kv_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + kv_proj=DeepSeekV3LatentLinear(in_dim=embed_dim, + out_dim=num_heads * (q_head_dim - qk_rope_head_dim + v_head_dim), + rank=kv_lora_rank, + norm=RMSNorm(dim=kv_lora_rank), + rope_head_dim=qk_rope_head_dim), + output_proj=nn.Linear(num_heads * v_head_dim, embed_dim, bias=False), pos_embeddings=rope, - q_norm=RMSNorm(dim=embed_dim, eps=norm_eps), max_seq_len=max_seq_len, is_causal=True, attn_dropout=0.0, @@ -82,7 +94,8 @@ def deepseek_v3( if is_moe: mlp_layer = MoE( experts=GroupedExperts( - + dim=embed_dim, + hidden_dim=moe_hidden_dim, num_experts=num_experts, ), router=DeepSeekV3TokenChoiceTopKRouter( @@ -95,12 +108,11 @@ def deepseek_v3( norm_topk_prob=norm_topk_prob, routed_scaling_factor=routed_scaling_factor, ), - shared_expert=deepseek_v3_mlp(embed_dim, mlp_hidden_dim), + shared_expert=deepseek_v3_mlp(embed_dim, moe_hidden_dim), ) else: mlp_layer = deepseek_v3_mlp(embed_dim, mlp_hidden_dim) - layer = TransformerSelfAttentionLayer( attn=self_attn, mlp=mlp_layer, @@ -122,8 +134,8 @@ def deepseek_v3( norm=RMSNorm(dim=embed_dim, eps=norm_eps), output=output_proj, ) - - + + def deepseek_v3_mlp( dim: int, hidden_dim: int @@ -135,5 +147,3 @@ def deepseek_v3_mlp( up_proj = nn.Linear(dim, hidden_dim, bias=False) down_proj = nn.Linear(hidden_dim, dim, bias=False) return FeedForward(gate_proj=gate_proj, up_proj=up_proj, down_proj=down_proj) - - diff --git a/torchtune/models/deepseek_v3/_convert_weights.py b/torchtune/models/deepseek_v3/_convert_weights.py new file mode 100644 index 0000000000..17db691966 --- /dev/null +++ b/torchtune/models/deepseek_v3/_convert_weights.py @@ -0,0 +1,193 @@ +from collections import defaultdict +import torch +from torchtune.models.convert_weights import get_mapped_key +import regex as re +# hf_model +# DeepseekV3ForCausalLM( +# (model): DeepseekV3Model( +# (embed_tokens): Identity() +# (layers): ModuleList( +# (0): DeepseekV3DecoderLayer( +# (self_attn): DeepseekV3Attention( +# (q_a_proj): Linear(in_features=16, out_features=16, bias=False) +# (q_a_layernorm): DeepseekV3RMSNorm((16,), eps=1e-06) +# (q_b_proj): Linear(in_features=16, out_features=64, bias=False) +# (kv_a_proj_with_mqa): Linear(in_features=16, out_features=32, bias=False) +# (kv_a_layernorm): DeepseekV3RMSNorm((16,), eps=1e-06) +# (kv_b_proj): Linear(in_features=16, out_features=64, bias=False) +# (o_proj): Linear(in_features=32, out_features=16, bias=False) +# ) +# (mlp): DeepseekV3MLP( +# (gate_proj): Linear(in_features=16, out_features=32, bias=False) +# (up_proj): Linear(in_features=16, out_features=32, bias=False) +# (down_proj): Linear(in_features=32, out_features=16, bias=False) +# (act_fn): SiLU() +# ) +# (input_layernorm): DeepseekV3RMSNorm((16,), eps=1e-06) +# (post_attention_layernorm): DeepseekV3RMSNorm((16,), eps=1e-06) +# ) +# (1): DeepseekV3DecoderLayer( +# (self_attn): DeepseekV3Attention( +# (q_a_proj): Linear(in_features=16, out_features=16, bias=False) +# (q_a_layernorm): DeepseekV3RMSNorm((16,), eps=1e-06) +# (q_b_proj): Linear(in_features=16, out_features=64, bias=False) +# 
(kv_a_proj_with_mqa): Linear(in_features=16, out_features=32, bias=False) +# (kv_a_layernorm): DeepseekV3RMSNorm((16,), eps=1e-06) +# (kv_b_proj): Linear(in_features=16, out_features=64, bias=False) +# (o_proj): Linear(in_features=32, out_features=16, bias=False) +# ) +# (mlp): DeepseekV3MoE( +# (experts): ModuleList( +# (0-255): 256 x DeepseekV3MLP( +# (gate_proj): Linear(in_features=16, out_features=16, bias=False) +# (up_proj): Linear(in_features=16, out_features=16, bias=False) +# (down_proj): Linear(in_features=16, out_features=16, bias=False) +# (act_fn): SiLU() +# ) +# ) +# (gate): DeepseekV3TopkRouter() +# (shared_experts): DeepseekV3MLP( +# (gate_proj): Linear(in_features=16, out_features=16, bias=False) +# (up_proj): Linear(in_features=16, out_features=16, bias=False) +# (down_proj): Linear(in_features=16, out_features=16, bias=False) +# (act_fn): SiLU() +# ) +# ) +# (input_layernorm): DeepseekV3RMSNorm((16,), eps=1e-06) +# (post_attention_layernorm): DeepseekV3RMSNorm((16,), eps=1e-06) +# ) +# ) +# (norm): DeepseekV3RMSNorm((16,), eps=1e-06) +# (rotary_emb): DeepseekV3RotaryEmbedding() +# ) +# (lm_head): Linear(in_features=16, out_features=129280, bias=False) +# ) +# TransformerDecoder( +# (tok_embeddings): Identity() +# (layers): ModuleList( +# (0): TransformerSelfAttentionLayer( +# (attn): DeepSeekV3Attention( +# (q_proj): DeepSeekV3LatentLinear( +# (a): Linear(in_features=16, out_features=16, bias=False) +# (b): Linear(in_features=16, out_features=64, bias=False) +# (norm): RMSNorm() +# ) +# (kv_proj): DeepSeekV3LatentLinear( +# (a): Linear(in_features=16, out_features=32, bias=False) +# (b): Linear(in_features=16, out_features=64, bias=False) +# (norm): RMSNorm() +# ) +# (output_proj): Linear(in_features=32, out_features=16, bias=False) +# (pos_embeddings): Identity() +# ) +# (mlp): FeedForward( +# (w1): Linear(in_features=16, out_features=32, bias=False) +# (w2): Linear(in_features=32, out_features=16, bias=False) +# (w3): Linear(in_features=16, out_features=32, bias=False) +# (activation): SiLU() +# ) +# (sa_norm): RMSNorm() +# (mlp_norm): RMSNorm() +# (sa_scale): Identity() +# (mlp_scale): Identity() +# ) +# (1): TransformerSelfAttentionLayer( +# (attn): DeepSeekV3Attention( +# (q_proj): DeepSeekV3LatentLinear( +# (a): Linear(in_features=16, out_features=16, bias=False) +# (b): Linear(in_features=16, out_features=64, bias=False) +# (norm): RMSNorm() +# ) +# (kv_proj): DeepSeekV3LatentLinear( +# (a): Linear(in_features=16, out_features=32, bias=False) +# (b): Linear(in_features=16, out_features=64, bias=False) +# (norm): RMSNorm() +# ) +# (output_proj): Linear(in_features=32, out_features=16, bias=False) +# (pos_embeddings): Identity() +# ) +# (mlp): MoE( +# (experts): GroupedExperts() +# (router): DeepSeekV3TokenChoiceTopKRouter( +# (gate): Linear(in_features=16, out_features=256, bias=False) +# ) +# (shared_expert): FeedForward( +# (w1): Linear(in_features=16, out_features=16, bias=False) +# (w2): Linear(in_features=16, out_features=16, bias=False) +# (w3): Linear(in_features=16, out_features=16, bias=False) +# (activation): SiLU() +# ) +# ) +# (sa_norm): RMSNorm() +# (mlp_norm): RMSNorm() +# (sa_scale): Identity() +# (mlp_scale): Identity() +# ) +# ) +# (norm): RMSNorm() +# (output): Linear(in_features=16, out_features=129280, bias=False) +# ) + + +# state dict key mappings from HF's format to torchtune's format for DeepSeek V3 +# Note: Conversion might require custom logic beyond simple key mapping, +# especially for kv_proj and MoE expert weights. 
+_FROM_HF = { + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale", + "model.norm.weight": "norm.scale", + # attenion weights + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight", + "model.layers.{}.self_attn.q_a_proj.weight": "layers.{}.attn.q_proj.a.weight", + "model.layers.{}.self_attn.q_a_layernorm.weight": "layers.{}.attn.q_proj.norm.scale", + "model.layers.{}.self_attn.q_b_proj.weight": "layers.{}.attn.q_proj.b.weight", + "model.layers.{}.self_attn.kv_a_proj_with_mqa.weight": "layers.{}.attn.kv_proj.a.weight", + "model.layers.{}.self_attn.kv_a_layernorm.weight": "layers.{}.attn.kv_proj.norm.scale", + "model.layers.{}.self_attn.kv_b_proj.weight": "layers.{}.attn.kv_proj.b.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight", + + # mlp (non-expert weights) + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.w1.weight", + "model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.w3.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight", + + # mlp MoE shared expert weights + "model.layers.{}.mlp.shared_experts.gate_proj.weight": "layers.{}.mlp.shared_expert.w1.weight", + "model.layers.{}.mlp.shared_experts.up_proj.weight": "layers.{}.mlp.shared_expert.w3.weight", + "model.layers.{}.mlp.shared_experts.down_proj.weight": "layers.{}.mlp.shared_expert.w2.weight", + + # mlp MoE token router weights + "model.layers.{}.mlp.gate.weight": "layers.{}.mlp.router.gate.weight", + "model.layers.{}.mlp.gate.e_score_correction_bias": "layers.{}.mlp.router.e_score_correction_bias", + + "lm_head.weight": "output.weight", + "model.layers.{}.self_attn.rotary_emb.inv_freq": None, +} + +def deepseek_v3_hf_to_tune(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + converted_state_dict = {} + # first merge expert weights + expert_weights_grouped = defaultdict(lambda: defaultdict(list)) + expert_keys_processed = set() + for key, value in state_dict.items(): + expert_match = re.match(r"model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)\.weight", key) + if expert_match: + layer_idx = expert_match.group(1) + expert_idx = int(expert_match.group(2)) + proj_name_part = expert_match.group(3) + expert_weights_grouped[layer_idx][proj_name_part].append((expert_idx, value)) + expert_keys_processed.add(key) + + for layer_idx, projections in expert_weights_grouped.items(): + for proj_type, weights_list in projections.items(): + weights_list.sort(key=lambda x: x[0]) + stacked_weights = torch.stack([w[1] for w in weights_list], dim=0) + new_key = f"layers.{layer_idx}.mlp.experts.{proj_type}" + converted_state_dict[new_key] = stacked_weights + + for key, value in state_dict.items(): + if key not in expert_keys_processed and "rotary_emb.inv_freq" not in key: + new_key = get_mapped_key(key, _FROM_HF) + converted_state_dict[new_key] = value + return converted_state_dict diff --git a/torchtune/models/deepseek_v3/_linear.py b/torchtune/models/deepseek_v3/_linear.py index 8ab3495134..b9b07414cb 100644 --- a/torchtune/models/deepseek_v3/_linear.py +++ b/torchtune/models/deepseek_v3/_linear.py @@ -9,31 +9,33 @@ from typing import Optional, Optional from torchtune.modules import RMSNorm + class DeepSeekV3LatentLinear(nn.Module): def __init__( self, + *, in_dim: int, out_dim: int, rank: int, + norm: nn.Module, rope_head_dim: Optional[int] = None, ): super().__init__() 
- self.rope_head_dim = rope_head_dim - intermediate_dim = rope_head_dim + rank if rope_head_dim else rank - self.a_proj = nn.Linear( - in_features=in_dim, out_features=intermediate_dim, bias=False + self.rope_head_dim = rope_head_dim or 0 + self.a = nn.Linear( + in_features=in_dim, out_features=rank + self.rope_head_dim, bias=False ) - self.b_proj = nn.Linear(in_features=intermediate_dim, out_features=out_dim, bias=False) - self.norm = RMSNorm(rank) + self.b = nn.Linear(in_features=rank, out_features=out_dim, bias=False) + self.norm = norm def forward(self, x: torch.Tensor) -> torch.Tensor: b, s_x, _ = x.shape - out = self.a_proj(x) + out = self.a(x) if self.rope_head_dim: out, rope_out = torch.split(out, [self.rank, self.rope_head_dim], dim=-1) rope_out = rope_out.view(b, s_x, 1, self.rope_head_dim).transpose(1, 2) - out = self.b_proj(self.norm(out)) + out = self.b(self.norm(out)) return out, rope_out - return out + return self.b(self.norm(out)) diff --git a/torchtune/models/deepseek_v3/_moe.py b/torchtune/models/deepseek_v3/_moe.py index c12a3c2882..ccfefef079 100644 --- a/torchtune/models/deepseek_v3/_moe.py +++ b/torchtune/models/deepseek_v3/_moe.py @@ -1,4 +1,3 @@ - # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # @@ -20,6 +19,7 @@ def __init__(self, norm_topk_prob: bool, routed_scaling_factor: float ): + super().__init__() self.gate = gate self.dim = dim self.num_experts = num_experts diff --git a/torchtune/models/llama4/_convert_weights.py b/torchtune/models/llama4/_convert_weights.py index 9d48c13d9e..e640d751e0 100644 --- a/torchtune/models/llama4/_convert_weights.py +++ b/torchtune/models/llama4/_convert_weights.py @@ -235,7 +235,7 @@ def llama4_tune_to_hf( # Combine gate projection with up projection new_key = get_mapped_key(key, inverted_mapping_dict) up_proj = state_dict[key.replace("gate", "up")] - converted_state_dict[new_key] = torch.cat([value, up_proj], dim=-1) + converted_state_dict[new_key] = torch.cat([value, up_proj], dim=-1 ) continue elif key.endswith("experts.up_proj"): # Skip as already handled with gate projection diff --git a/torchtune/modules/moe/moe.py b/torchtune/modules/moe/moe.py index 8e606be0f3..73372dd0e5 100644 --- a/torchtune/modules/moe/moe.py +++ b/torchtune/modules/moe/moe.py @@ -141,6 +141,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: else: out = torch.zeros_like(x.reshape(bs * slen, dim)) if routed_output.numel() > 0: - out = out.scatter_add(dim=0, index=token_indices, src=routed_output) + out.scatter_add_(dim=0, index=token_indices, src=routed_output) out = out.reshape(bs, slen, dim) return out From 6bd4326be39f526ac4f1327c459dddca610f001b Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Tue, 13 May 2025 15:22:59 +0100 Subject: [PATCH 05/25] adding checkpointing --- torchtune/models/deepseek_v3/_attention.py | 12 ++++++++---- .../models/deepseek_v3/_component_builders.py | 15 ++++++++------- torchtune/models/deepseek_v3/_convert_weights.py | 14 +++++++++----- torchtune/models/deepseek_v3/_linear.py | 1 + torchtune/models/deepseek_v3/_moe.py | 4 +++- torchtune/training/checkpointing/_checkpointer.py | 12 ++++++++++++ torchtune/training/checkpointing/_utils.py | 1 + 7 files changed, 42 insertions(+), 17 deletions(-) diff --git a/torchtune/models/deepseek_v3/_attention.py b/torchtune/models/deepseek_v3/_attention.py index 2b35f197f7..d22f467a46 100644 --- a/torchtune/models/deepseek_v3/_attention.py +++ b/torchtune/models/deepseek_v3/_attention.py @@ -9,6 +9,7 @@ from typing import Optional from 
torchtune.modules.attention_utils import _MaskType from torchtune.modules import RMSNorm +from torchtune.modules.attention import _sdpa_or_flex_attention class DeepSeekV3Attention(nn.Module): @@ -22,7 +23,7 @@ def __init__(self, q_proj: nn.Module, kv_proj: nn.Module, output_proj: nn.Module, - pos_embeddings: Optional[nn.Module] = None, + pos_embeddings: nn.Module, max_seq_len: int = 4096, is_causal: bool = True, attn_dropout: float = 0.0,): @@ -45,8 +46,9 @@ def __init__(self, self.pos_embeddings = pos_embeddings self.softmax_scale = self.q_head_dim ** (-0.5) self.cache_enabled = False + + self._attention_call = _sdpa_or_flex_attention() - def forward( self, x: torch.Tensor, @@ -55,6 +57,8 @@ def forward( mask: Optional[_MaskType] = None, input_pos: Optional[torch.Tensor] = None, ) -> torch.Tensor: + + # import ipdb; ipdb.set_trace() # q is sometimes decomposed into A/B # kv is *always* decomposed @@ -76,7 +80,7 @@ def forward( ) kv, k_pe = self.kv_proj(x) - kv = kv.view(b, s_x, self.num_kv_heads, self.qk_nope_head_dim + self.v_head_dim) + kv = kv.view(b, s_x, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) kv = kv.transpose(1, 2) k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) @@ -98,7 +102,7 @@ def forward( value_states, mask=mask, dropout_p=self.attn_dropout if self.training else 0.0, - is_causal=self.kv_cache is None and mask is None and self.is_causal, + is_causal=mask is None, ) # reshape the output to be the same shape as the input diff --git a/torchtune/models/deepseek_v3/_component_builders.py b/torchtune/models/deepseek_v3/_component_builders.py index c8fcc1c0c4..109d175c7c 100644 --- a/torchtune/models/deepseek_v3/_component_builders.py +++ b/torchtune/models/deepseek_v3/_component_builders.py @@ -38,6 +38,7 @@ def deepseek_v3( moe_every_n_layers: Optional[int] = None, first_moe_layer: Optional[int] = None, num_experts: Optional[int] = None, + num_shared_experts: Optional[int] = None, num_groups: Optional[int] = None, topk_groups: Optional[int] = None, norm_topk_prob: Optional[float] = None, @@ -47,8 +48,8 @@ def deepseek_v3( moe_hidden_dim: Optional[int] = None, norm_eps: float = 1e-5, ): - head_dim = embed_dim // num_heads - rope = nn.Identity() + def rope(x, input_pos=None): + return x layers = [] for i in range(num_layers): @@ -74,10 +75,10 @@ def deepseek_v3( self_attn = DeepSeekV3Attention( embed_dim=embed_dim, num_heads=num_heads, - qk_rope_head_dim=head_dim, + qk_rope_head_dim=qk_rope_head_dim, v_head_dim=v_head_dim, - qk_nope_head_dim=head_dim, - q_head_dim=head_dim, + qk_nope_head_dim=qk_nope_head_dim, + q_head_dim=q_head_dim, q_proj=q_proj, kv_proj=DeepSeekV3LatentLinear(in_dim=embed_dim, out_dim=num_heads * (q_head_dim - qk_rope_head_dim + v_head_dim), @@ -108,7 +109,7 @@ def deepseek_v3( norm_topk_prob=norm_topk_prob, routed_scaling_factor=routed_scaling_factor, ), - shared_expert=deepseek_v3_mlp(embed_dim, moe_hidden_dim), + shared_expert=deepseek_v3_mlp(embed_dim, moe_hidden_dim * num_shared_experts), ) else: mlp_layer = deepseek_v3_mlp(embed_dim, mlp_hidden_dim) @@ -130,7 +131,7 @@ def deepseek_v3( layers=layers, max_seq_len=max_seq_len, num_heads=num_heads, - head_dim=head_dim, + head_dim=embed_dim // num_heads, norm=RMSNorm(dim=embed_dim, eps=norm_eps), output=output_proj, ) diff --git a/torchtune/models/deepseek_v3/_convert_weights.py b/torchtune/models/deepseek_v3/_convert_weights.py index 17db691966..33886716e8 100644 --- a/torchtune/models/deepseek_v3/_convert_weights.py +++ 
b/torchtune/models/deepseek_v3/_convert_weights.py @@ -136,7 +136,7 @@ "model.embed_tokens.weight": "tok_embeddings.weight", "model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale", "model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale", - "model.norm.weight": "norm.scale", + "model.norm.weight": "norm.scale", # attenion weights "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight", "model.layers.{}.self_attn.q_a_proj.weight": "layers.{}.attn.q_proj.a.weight", @@ -147,11 +147,11 @@ "model.layers.{}.self_attn.kv_b_proj.weight": "layers.{}.attn.kv_proj.b.weight", "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight", - # mlp (non-expert weights) + # mlp non-expert weights "model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.w1.weight", "model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.w3.weight", "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight", - + # mlp MoE shared expert weights "model.layers.{}.mlp.shared_experts.gate_proj.weight": "layers.{}.mlp.shared_expert.w1.weight", "model.layers.{}.mlp.shared_experts.up_proj.weight": "layers.{}.mlp.shared_expert.w3.weight", @@ -165,24 +165,28 @@ "model.layers.{}.self_attn.rotary_emb.inv_freq": None, } + def deepseek_v3_hf_to_tune(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: converted_state_dict = {} + # first merge expert weights expert_weights_grouped = defaultdict(lambda: defaultdict(list)) expert_keys_processed = set() for key, value in state_dict.items(): - expert_match = re.match(r"model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)\.weight", key) + expert_match = re.match( + r"model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)\.weight", key) if expert_match: layer_idx = expert_match.group(1) expert_idx = int(expert_match.group(2)) proj_name_part = expert_match.group(3) + expert_weights_grouped[layer_idx][proj_name_part].append((expert_idx, value)) expert_keys_processed.add(key) for layer_idx, projections in expert_weights_grouped.items(): for proj_type, weights_list in projections.items(): weights_list.sort(key=lambda x: x[0]) - stacked_weights = torch.stack([w[1] for w in weights_list], dim=0) + stacked_weights = torch.stack([w[1].transpose(0, 1) for w in weights_list], dim=0) new_key = f"layers.{layer_idx}.mlp.experts.{proj_type}" converted_state_dict[new_key] = stacked_weights diff --git a/torchtune/models/deepseek_v3/_linear.py b/torchtune/models/deepseek_v3/_linear.py index b9b07414cb..df242ee04f 100644 --- a/torchtune/models/deepseek_v3/_linear.py +++ b/torchtune/models/deepseek_v3/_linear.py @@ -21,6 +21,7 @@ def __init__( rope_head_dim: Optional[int] = None, ): super().__init__() + self.rank = rank self.rope_head_dim = rope_head_dim or 0 self.a = nn.Linear( in_features=in_dim, out_features=rank + self.rope_head_dim, bias=False diff --git a/torchtune/models/deepseek_v3/_moe.py b/torchtune/models/deepseek_v3/_moe.py index ccfefef079..611fb045fd 100644 --- a/torchtune/models/deepseek_v3/_moe.py +++ b/torchtune/models/deepseek_v3/_moe.py @@ -35,6 +35,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: logits = self.gate(x) # calculate scores for every expert in every group + # import ipdb; ipdb.set_trace() scores = torch.sigmoid(logits.to(torch.float32)).to(x.dtype) scores += self.e_score_correction_bias.unsqueeze(0) @@ -60,7 +61,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: score_mask = ( group_mask.unsqueeze(-1) .expand( - n, self.n_groups, experts_per_group 
+ n, self.num_groups, experts_per_group ) .reshape(n, -1) ) @@ -96,4 +97,5 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: token_idxs_experts_sorted = ( token_idxs_experts_sorted // self.experts_per_token ) + print(scores_per_expert.isnan().any(), token_idxs_experts_sorted.isnan().any(), num_tokens_per_expert.isnan().any()) return scores_per_expert, token_idxs_experts_sorted, num_tokens_per_expert diff --git a/torchtune/training/checkpointing/_checkpointer.py b/torchtune/training/checkpointing/_checkpointer.py index bef928448a..9394b20a15 100644 --- a/torchtune/training/checkpointing/_checkpointer.py +++ b/torchtune/training/checkpointing/_checkpointer.py @@ -663,6 +663,12 @@ def load_checkpoint(self) -> Dict[str, Any]: converted_state_dict[training.MODEL_KEY] = llama4_hf_to_tune( merged_state_dict, ) + elif self._model_type == ModelType.DEEPSEEK_V3: + from torchtune.models.deepseek_v3._convert_weights import deepseek_v3_hf_to_tune + + converted_state_dict[training.MODEL_KEY] = deepseek_v3_hf_to_tune( + merged_state_dict, + ) else: converted_state_dict[training.MODEL_KEY] = convert_weights.hf_to_tune( merged_state_dict, @@ -774,6 +780,12 @@ def save_checkpoint( state_dict[training.MODEL_KEY] = llama4_tune_to_hf( state_dict[training.MODEL_KEY], ) + elif self._model_type == ModelType.DEEPSEEK_V3: + from torchtune.models.deepseek_v3._convert_weights import deepseek_v3_tune_to_hf + + state_dict[training.MODEL_KEY] = deepseek_v3_tune_to_hf( + state_dict[training.MODEL_KEY], + ) else: state_dict[training.MODEL_KEY] = convert_weights.tune_to_hf( state_dict[training.MODEL_KEY], diff --git a/torchtune/training/checkpointing/_utils.py b/torchtune/training/checkpointing/_utils.py index 61113401d7..4324053852 100644 --- a/torchtune/training/checkpointing/_utils.py +++ b/torchtune/training/checkpointing/_utils.py @@ -109,6 +109,7 @@ class ModelType(Enum): >>> state_dict = my_custom_state_dict_mapping(state_dict) """ + DEEPSEEK_V3: str = "deepseek_v3" GEMMA: str = "gemma" GEMMA2: str = "gemma2" LLAMA2: str = "llama2" From 87c998cb706981261a12c7994bf33c40dcfb5f69 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Sun, 18 May 2025 11:45:58 +0100 Subject: [PATCH 06/25] parity achieved --- torchtune/models/deepseek_v3/_attention.py | 7 +- .../models/deepseek_v3/_component_builders.py | 5 +- .../models/deepseek_v3/_model_builders.py | 5 + torchtune/models/deepseek_v3/_moe.py | 76 ++++++++++- .../deepseek_v3/_position_embeddings.py | 118 ++++++++++++++++++ torchtune/modules/attention_utils.py | 8 +- torchtune/modules/moe/experts.py | 1 + torchtune/modules/moe/moe.py | 21 ++-- 8 files changed, 221 insertions(+), 20 deletions(-) diff --git a/torchtune/models/deepseek_v3/_attention.py b/torchtune/models/deepseek_v3/_attention.py index d22f467a46..089f062da7 100644 --- a/torchtune/models/deepseek_v3/_attention.py +++ b/torchtune/models/deepseek_v3/_attention.py @@ -83,12 +83,13 @@ def forward( kv = kv.view(b, s_x, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) kv = kv.transpose(1, 2) + k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) q_pe = self.pos_embeddings(q_pe, input_pos=input_pos) k_pe = self.pos_embeddings(k_pe, input_pos=input_pos) - query_states = q_pe.new_empty(b, self.num_heads, s_x, self.q_head_dim) + query_states = k_pe.new_empty(b, self.num_heads, s_x, self.q_head_dim) query_states[:, :, :, : self.qk_nope_head_dim] = q_nope query_states[:, :, :, self.qk_nope_head_dim :] = q_pe @@ -96,6 +97,7 @@ def forward( key_states[:, :, :, : 
self.qk_nope_head_dim] = k_nope key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + output = self._attention_call( query_states, key_states, @@ -103,8 +105,9 @@ def forward( mask=mask, dropout_p=self.attn_dropout if self.training else 0.0, is_causal=mask is None, + scale=self.softmax_scale, ) - # reshape the output to be the same shape as the input output = output.transpose(1, 2).contiguous().view(b, s_x, -1) + return self.output_proj(output) diff --git a/torchtune/models/deepseek_v3/_component_builders.py b/torchtune/models/deepseek_v3/_component_builders.py index 109d175c7c..a949584712 100644 --- a/torchtune/models/deepseek_v3/_component_builders.py +++ b/torchtune/models/deepseek_v3/_component_builders.py @@ -10,7 +10,7 @@ from torch import nn from torchtune.models.deepseek_v3._linear import DeepSeekV3LatentLinear from torchtune.models.deepseek_v3._attention import DeepSeekV3Attention -from torchtune.models.deepseek_v3._moe import DeepSeekV3TokenChoiceTopKRouter +from torchtune.models.deepseek_v3._moe import DeepSeekV3TokenChoiceTopKRouter, DeepseekV3MoE from torchtune.modules import ( FeedForward, RMSNorm, @@ -18,7 +18,6 @@ TransformerSelfAttentionLayer, ) from torchtune.modules.moe.experts import GroupedExperts -from torchtune.modules.moe.moe import MoE from torchtune.modules.position_embeddings import RotaryPositionalEmbeddings @@ -93,7 +92,7 @@ def rope(x, input_pos=None): ) is_moe = (moe_every_n_layers is None or (i + 1) % moe_every_n_layers == 0) and i >= first_moe_layer if is_moe: - mlp_layer = MoE( + mlp_layer = DeepseekV3MoE( experts=GroupedExperts( dim=embed_dim, hidden_dim=moe_hidden_dim, diff --git a/torchtune/models/deepseek_v3/_model_builders.py b/torchtune/models/deepseek_v3/_model_builders.py index e69de29bb2..04ddc179ad 100644 --- a/torchtune/models/deepseek_v3/_model_builders.py +++ b/torchtune/models/deepseek_v3/_model_builders.py @@ -0,0 +1,5 @@ + +# def deepseek_v3_671b_256e( + +# ) -> TransformerDecoder: +# pass diff --git a/torchtune/models/deepseek_v3/_moe.py b/torchtune/models/deepseek_v3/_moe.py index 611fb045fd..c5e4b6e7e7 100644 --- a/torchtune/models/deepseek_v3/_moe.py +++ b/torchtune/models/deepseek_v3/_moe.py @@ -6,6 +6,76 @@ import torch from torch import nn +from typing import Optional + +class DeepseekV3MoE(nn.Module): + """This class implements the moe layer which is Mixture of Experts. Mixture of Experts + typically consists of a set of expert networks, alongside with a router, which directs input tokens + to the appropriate experts. See more details in https://arxiv.org/2401.0606. + + This class is identical to :class:`~torchtune.modules.moe.moe.MoE`, except that it applies the + router weighting scores to the *output* of the experts, rather than the input. + + Args: + experts (nn.Module): experts module. + router (nn.Module): router module. + shared_expert (Optional[nn.Module]): shared expert module. Default is None. + """ + + def __init__( + self, + *, + experts: nn.Module, + router: nn.Module, + shared_expert: Optional[nn.Module] = None, + ): + super().__init__() + self.experts = experts + self.router = router + self.shared_expert = shared_expert + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``. + + Returns: + out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. 
+ """ + b, s, dim = x.shape + # top_scores and selected_indices shape (bs*slen*experts_per_token,) + # num_tokens_per_expert shape (num_experts,) + ( + top_scores, + token_indices, + num_tokens_per_expert, + ) = self.router(x.reshape(b * s, dim)) + + # shape (b*s*experts_per_token, dim) + token_indices = token_indices.reshape(-1, 1).expand(-1, dim) + + # shape (b*s*experts_per_token, dim) + routed_input = torch.gather( + x.view(-1, dim), + dim=0, + index=token_indices, + ) + + # shape (b*s*top_k, dim) + routed_output = self.experts(routed_input, num_tokens_per_expert) + + routed_output = routed_output * top_scores.reshape(-1, 1) + + # import ipdb; ipdb.set_trace() + # shared expert + if self.shared_expert is not None: + out = self.shared_expert(x).reshape(b * s, dim) + else: + out = torch.zeros_like(x.reshape(b * s, dim)) + if routed_output.numel() > 0: + out.scatter_add_(dim=0, index=token_indices, src=routed_output) + out = out.reshape(b, s, dim) + return out class DeepSeekV3TokenChoiceTopKRouter(nn.Module): @@ -33,7 +103,6 @@ def __init__(self, def forward(self, x: torch.Tensor) -> torch.Tensor: n = x.shape[0] logits = self.gate(x) - # calculate scores for every expert in every group # import ipdb; ipdb.set_trace() scores = torch.sigmoid(logits.to(torch.float32)).to(x.dtype) @@ -81,9 +150,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: scores_per_token /= denominator # apply scaling factor - scores_per_token = ( - scores_per_token * self.routed_scaling_factor - ) + scores_per_token *= self.routed_scaling_factor num_tokens_per_expert = torch.histc( selected_experts_idxs.float(), bins=self.num_experts, min=0, max=self.num_experts - 1 @@ -97,5 +164,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: token_idxs_experts_sorted = ( token_idxs_experts_sorted // self.experts_per_token ) - print(scores_per_expert.isnan().any(), token_idxs_experts_sorted.isnan().any(), num_tokens_per_expert.isnan().any()) return scores_per_expert, token_idxs_experts_sorted, num_tokens_per_expert diff --git a/torchtune/models/deepseek_v3/_position_embeddings.py b/torchtune/models/deepseek_v3/_position_embeddings.py index 2e41cd717f..e2b64e5265 100644 --- a/torchtune/models/deepseek_v3/_position_embeddings.py +++ b/torchtune/models/deepseek_v3/_position_embeddings.py @@ -3,3 +3,121 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + +from typing import Any, Optional + +import torch +from torch import nn + + +class YaRNRotaryPositionalEmbeddings(nn.Module): + """ + This class implements Rotary Positional Embeddings (RoPE) + proposed in https://arxiv.org/abs/2104.09864. + + Reference implementation (used for correctness verfication) + can be found here: + https://github.com/meta-llama/llama/blob/main/llama/model.py#L80 + + In this implementation we cache the embeddings for each position upto + ``max_seq_len`` by computing this during init. + + Args: + dim (int): Embedding dimension. 
This is usually set to the dim of each + head in the attention module computed as ``embed_dim // num_heads`` + max_seq_len (int): Maximum expected sequence length for the + model, if exceeded the cached freqs will be recomputed + base (int): The base for the geometric progression used to compute + the rotation angles + """ + + def __init__( + self, + dim: int, + max_seq_len: int = 4096, + base: int = 10_000, + ) -> None: + super().__init__() + self.dim = dim + self.base = base + self.max_seq_len = max_seq_len + self.rope_init() + + def rope_init(self): + theta = 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim) + ) + self.register_buffer("theta", theta, persistent=False) + self.build_rope_cache(self.max_seq_len) + + def build_rope_cache(self, max_seq_len: int = 4096) -> None: + # Create position indexes `[0, 1, ..., max_seq_len - 1]` + seq_idx = torch.arange( + max_seq_len, dtype=self.theta.dtype, device=self.theta.device + ) + + # Outer product of theta and position index; output tensor has + # a shape of [max_seq_len, dim // 2] + idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float() + + # cache includes both the cos and sin components and so the output shape is + # [max_seq_len, dim // 2, 2] + cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) + self.register_buffer("cache", cache, persistent=False) + + def forward( + self, x: torch.Tensor, *, input_pos: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Args: + x (torch.Tensor): input tensor with shape + ``[b, s, n_h, h_d]`` + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape [b, s]. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + + Returns: + torch.Tensor: output tensor with shape ``[b, s, n_h, h_d]`` + + Notation used for tensor shapes: + - b: batch size + - s: sequence length + - n_h: num heads + - h_d: head dim + """ + # input tensor has shape [b, s, n_h, h_d] + seq_len = x.size(1) + + # extract the values based on whether input_pos is set or not + rope_cache = ( + self.cache[:seq_len] if input_pos is None else self.cache[input_pos] + ) + + # reshape input; the last dimension is used for computing the output. 
+ # Cast to float to match the reference implementation + # tensor has shape [b, s, n_h, h_d // 2, 2] + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + + # reshape the cache for broadcasting + # tensor has shape [b, s, 1, h_d // 2, 2] if packed samples, + # otherwise has shape [1, s, 1, h_d // 2, 2] + rope_cache = rope_cache.view(-1, xshaped.size(1), 1, xshaped.size(3), 2) + + # tensor has shape [b, s, n_h, h_d // 2, 2] + x_out = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] + - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + + # tensor has shape [b, s, n_h, h_d] + x_out = x_out.flatten(3) + return x_out.type_as(x) + diff --git a/torchtune/modules/attention_utils.py b/torchtune/modules/attention_utils.py index 9d456ca3ea..df55642076 100644 --- a/torchtune/modules/attention_utils.py +++ b/torchtune/modules/attention_utils.py @@ -53,8 +53,9 @@ def compile_friendly_flex_attention( k: torch.Tensor, v: torch.Tensor, block_mask: BlockMask, + scale: float, ) -> torch.Tensor: - return flex_attention_compiled(q, k, v, block_mask=block_mask) + return flex_attention_compiled(q, k, v, block_mask=block_mask, scale=scale) _MaskType = Union[torch.Tensor, BlockMask] else: @@ -201,6 +202,7 @@ def _attention_call( mask: Optional[_MaskType], dropout_p: float, is_causal: bool, + scale: Optional[float] = None, ) -> torch.Tensor: # Flex attention uses the BlockMask @@ -224,6 +226,7 @@ def _attention_call( k, v, block_mask=mask, + scale=scale, ) # If mask is a standard boolean tensor or None, then use SDPA else: @@ -239,6 +242,7 @@ def _attention_call( attn_mask=mask, dropout_p=dropout_p, is_causal=is_causal, + scale=scale, ) else: @@ -250,6 +254,7 @@ def _attention_call( mask: Optional[_MaskType], dropout_p: float, is_causal: bool, + scale: Optional[float] = None, ) -> torch.Tensor: # shape: [b, 1, s, s] if mask is not None: @@ -263,6 +268,7 @@ def _attention_call( attn_mask=mask, dropout_p=dropout_p, is_causal=is_causal, + scale=scale, ) return _attention_call diff --git a/torchtune/modules/moe/experts.py b/torchtune/modules/moe/experts.py index e4838d961b..02f89dbc6b 100644 --- a/torchtune/modules/moe/experts.py +++ b/torchtune/modules/moe/experts.py @@ -56,6 +56,7 @@ def forward( torch.Tensor: tensor with shape (bsz * seq_len * experts_per_token, dim) """ + # import ipdb; ipdb.set_trace() # a tuple of tensors indexed by experts # each with shape (tokens_per_expert(varying), dim) x = torch.split( diff --git a/torchtune/modules/moe/moe.py b/torchtune/modules/moe/moe.py index 73372dd0e5..f1423bf764 100644 --- a/torchtune/modules/moe/moe.py +++ b/torchtune/modules/moe/moe.py @@ -85,6 +85,7 @@ class MoE(nn.Module): """This class implements the moe layer which is Mixture of Experts. Mixture of Experts typically consists of a set of expert networks, alongside with a router, which directs input tokens to the appropriate experts. See more details in https://arxiv.org/pdf/2407.06204. + Args: experts (nn.Module): experts module. @@ -112,35 +113,37 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. 
""" - bs, slen, dim = x.shape + b, s, dim = x.shape # top_scores and selected_indices shape (bs*slen*experts_per_token,) # num_tokens_per_expert shape (num_experts,) ( top_scores, token_indices, num_tokens_per_expert, - ) = self.router(x.reshape(bs * slen, dim)) + ) = self.router(x.reshape(b * s, dim)) - # shape (bs*slen*experts_per_token, dim) + # shape (b*s*experts_per_token, dim) token_indices = token_indices.reshape(-1, 1).expand(-1, dim) - # shape (bs*slen*experts_per_token, dim) + # shape (b*s*experts_per_token, dim) routed_input = torch.gather( x.view(-1, dim), dim=0, index=token_indices, ) - routed_input = routed_input * top_scores.reshape(-1, 1) + # routed_input = routed_input * top_scores.reshape(-1, 1) - # shape (bs*slen*top_k, dim) + # shape (b*s*top_k, dim) routed_output = self.experts(routed_input, num_tokens_per_expert) + routed_output = routed_output * top_scores.reshape(-1, 1) + # import ipdb; ipdb.set_trace() # shared expert if self.shared_expert is not None: - out = self.shared_expert(x).reshape(bs * slen, dim) + out = self.shared_expert(x).reshape(b * s, dim) else: - out = torch.zeros_like(x.reshape(bs * slen, dim)) + out = torch.zeros_like(x.reshape(b * s, dim)) if routed_output.numel() > 0: out.scatter_add_(dim=0, index=token_indices, src=routed_output) - out = out.reshape(bs, slen, dim) + out = out.reshape(b, s, dim) return out From 263bf4ecf989a62a71e923506285815becf59252 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 22 May 2025 09:45:08 +0100 Subject: [PATCH 07/25] adding base rope --- torchtune/datasets/_instruct.py | 2 +- torchtune/models/deepseek_v3/_attention.py | 8 +- .../models/deepseek_v3/_component_builders.py | 19 +- torchtune/models/deepseek_v3/_moe.py | 7 +- .../deepseek_v3/_position_embeddings.py | 425 ++++++++++++++---- .../models/qwen2/_positional_embeddings.py | 2 +- 6 files changed, 346 insertions(+), 117 deletions(-) diff --git a/torchtune/datasets/_instruct.py b/torchtune/datasets/_instruct.py index 20168aac1d..0ad244667b 100644 --- a/torchtune/datasets/_instruct.py +++ b/torchtune/datasets/_instruct.py @@ -116,7 +116,7 @@ def instruct_dataset( dataset: _component_: torchtune.datasets.instruct_dataset source: json - data_files: my_dataset.json + data_files: my_dataset.json column_map: input: question output: answer diff --git a/torchtune/models/deepseek_v3/_attention.py b/torchtune/models/deepseek_v3/_attention.py index 089f062da7..69dc8732c2 100644 --- a/torchtune/models/deepseek_v3/_attention.py +++ b/torchtune/models/deepseek_v3/_attention.py @@ -74,7 +74,7 @@ def forward( q = self.q_proj(x) q = q.view(b, s_x, self.num_heads, self.q_head_dim) q = q.transpose(1, 2) - + q_nope, q_pe = torch.split( q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) @@ -83,12 +83,16 @@ def forward( kv = kv.view(b, s_x, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) kv = kv.transpose(1, 2) - k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + q_pe = q_pe.transpose(1, 2) + k_pe = k_pe.transpose(1, 2) q_pe = self.pos_embeddings(q_pe, input_pos=input_pos) k_pe = self.pos_embeddings(k_pe, input_pos=input_pos) + q_pe = q_pe.transpose(1, 2) + k_pe = k_pe.transpose(1, 2) query_states = k_pe.new_empty(b, self.num_heads, s_x, self.q_head_dim) query_states[:, :, :, : self.qk_nope_head_dim] = q_nope query_states[:, :, :, self.qk_nope_head_dim :] = q_pe diff --git a/torchtune/models/deepseek_v3/_component_builders.py b/torchtune/models/deepseek_v3/_component_builders.py index a949584712..9666bf92fa 100644 
--- a/torchtune/models/deepseek_v3/_component_builders.py +++ b/torchtune/models/deepseek_v3/_component_builders.py @@ -47,20 +47,13 @@ def deepseek_v3( moe_hidden_dim: Optional[int] = None, norm_eps: float = 1e-5, ): - def rope(x, input_pos=None): - return x - layers = [] + use_yarn = False + if use_yarn: + pass + else: + rope = RotaryPositionalEmbeddings(dim=qk_rope_head_dim, max_seq_len=max_seq_len, base=rope_base) + layers = [] for i in range(num_layers): - - # q is sometimes decomposed into A/B (if q_lora_rank) - # kv is *always* decomposed - - # when q is decomposed the norm is applied but - # not otherwise - in this case the norm - # should be applied after q a proj and before q b proj - - # for kv decomposition pos embeddings need to be extracted before - # projecting back up q_head_dim = qk_rope_head_dim + qk_nope_head_dim if q_lora_rank is None: q_proj = nn.Linear(embed_dim, num_heads * q_head_dim, bias=False) diff --git a/torchtune/models/deepseek_v3/_moe.py b/torchtune/models/deepseek_v3/_moe.py index c5e4b6e7e7..74326e65c3 100644 --- a/torchtune/models/deepseek_v3/_moe.py +++ b/torchtune/models/deepseek_v3/_moe.py @@ -9,9 +9,10 @@ from typing import Optional class DeepseekV3MoE(nn.Module): - """This class implements the moe layer which is Mixture of Experts. Mixture of Experts - typically consists of a set of expert networks, alongside with a router, which directs input tokens - to the appropriate experts. See more details in https://arxiv.org/2401.0606. + """This class implements the Mixture of Experts (MoE) layer for DeepSeek V3. + This comprises a set of a router and a set of experts, which are typically smaller than MLP layers in standard + transformer models. The router is used to select a subset of experts for each token, and the selected experts are + then used to compute the output of the MoE layer. See more details in https://arxiv.org/2401.0606. This class is identical to :class:`~torchtune.modules.moe.moe.MoE`, except that it applies the router weighting scores to the *output* of the experts, rather than the input. diff --git a/torchtune/models/deepseek_v3/_position_embeddings.py b/torchtune/models/deepseek_v3/_position_embeddings.py index e2b64e5265..2231029957 100644 --- a/torchtune/models/deepseek_v3/_position_embeddings.py +++ b/torchtune/models/deepseek_v3/_position_embeddings.py @@ -1,123 +1,354 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. +import torch +import torch.nn as nn +import math +from typing import Optional -from typing import Any, Optional +# --- Helper Functions for YaRN --- -import torch -from torch import nn +def yarn_find_correction_dim(num_rotations: int, + dim: int, # Full head dimension + base: float = 10000.0, + original_max_position_embeddings: int = 2048) -> float: + """ + Calculates the dimension index (in the full dim space) at which a certain + number of full rotations occur. 
+ """ + return (dim * math.log(original_max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) +def yarn_find_correction_range(beta_fast: int, # Min number of rotations for high freqs + beta_slow: int, # Max number of rotations for low freqs + dim: int, # Full head dimension + base: float = 10000.0, + original_max_position_embeddings: int = 2048) -> tuple[int, int]: + """ + Finds the range of dimension indices [low_idx, high_idx] (in the dim//2 frequency space) + that correspond to the specified rotation counts. These define the ramp for YaRN's interpolation. + """ + # These are indices in the full dimension space (0 to dim-1) + low_idx_full_dim = math.floor(yarn_find_correction_dim(beta_fast, dim, base, original_max_position_embeddings)) + high_idx_full_dim = math.ceil(yarn_find_correction_dim(beta_slow, dim, base, original_max_position_embeddings)) -class YaRNRotaryPositionalEmbeddings(nn.Module): + # The ramp mask is applied to dim // 2 frequencies. + # Each frequency element corresponds to two dimensions. + # So, we need to map these full_dim indices to the frequency_dim (dim//2) space. + # An index 'd' in full_dim corresponds to 'd // 2' in frequency_dim. + # However, DeepSeek's code uses these bounds directly with a mask of length dim//2. + # This implies that 'low_idx_full_dim' and 'high_idx_full_dim' are treated as bounds + # for the elements of the frequency vector (which has length dim//2). + # Let's stick to that interpretation for consistency with the reference. + + # Clamp values to be within valid indices for an array of length dim // 2 + # (i.e., 0 to dim//2 - 1) + dim_half = dim // 2 + low_idx_for_mask = max(low_idx_full_dim, 0) # Should be max(low_idx_full_dim // 2, 0) if strictly mapping + high_idx_for_mask = min(high_idx_full_dim, dim_half -1) # Should be min(high_idx_full_dim // 2, dim_half -1) + + # DeepSeek's `yarn_find_correction_range` returns `max(low,0), min(high, dim-1)` + # and then `yarn_linear_ramp_mask` takes `dim//2` as its length. + # The `low` and `high` are used directly as bounds for the mask of length `dim//2`. + # This means `low` and `high` are effectively indices into the `dim//2` array. + # So, the clamping should be against `dim_half - 1`. + + # Re-evaluating based on deepseek_tt.py: + # yarn_find_correction_range(self.beta_fast, self.beta_slow, dim, ...) + # -> low, high + # yarn_linear_ramp_mask(low, high, dim // 2) + # This implies `low` and `high` from `yarn_find_correction_range` are directly + # used as bounds for the mask of length `dim // 2`. + # The `dim` passed to `yarn_find_correction_range` is the full head_dim. + # The `dim` passed to `yarn_linear_ramp_mask` is `head_dim // 2`. + # The `low` and `high` values from `yarn_find_correction_range` are indices + # that can range up to `head_dim - 1`. + # When used in `yarn_linear_ramp_mask(low, high, head_dim // 2)`, these `low` and `high` + # are used as the `min_val` and `max_val` for a ramp over `head_dim // 2` elements. + # This seems to imply a scaling or interpretation of `low` and `high` within the ramp function. + # Let's assume the `yarn_linear_ramp_mask` expects `min_val` and `max_val` to be + # meaningful indices *within the range of `num_dims_to_mask`*. + # The `low` and `high` from `yarn_find_correction_range` in deepseek_tt are indeed + # clamped against `dim-1` (full dim). + # The most direct interpretation from deepseek_tt is that `low` and `high` are used as is. 
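+    # Rough worked example (values assumed here purely for illustration): with
+    # base=10000, dim=64, original_max_position_embeddings=2048, beta_fast=32
+    # and beta_slow=1, yarn_find_correction_dim returns roughly 8.06 and 20.1,
+    # so after floor/ceil the ramp spans dimension indices ~8 to ~21.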
+ + return max(low_idx_full_dim, 0), min(high_idx_full_dim, dim -1) # Return bounds in full_dim space + + +def yarn_linear_ramp_mask(min_val: float, # Start boundary for the ramp (can be outside 0 to num_dims_to_mask-1) + max_val: float, # End boundary for the ramp + num_dims_to_mask: int # Length of the mask, e.g., head_dim // 2 + ) -> torch.Tensor: """ - This class implements Rotary Positional Embeddings (RoPE) - proposed in https://arxiv.org/abs/2104.09864. - - Reference implementation (used for correctness verfication) - can be found here: - https://github.com/meta-llama/llama/blob/main/llama/model.py#L80 - - In this implementation we cache the embeddings for each position upto - ``max_seq_len`` by computing this during init. - - Args: - dim (int): Embedding dimension. This is usually set to the dim of each - head in the attention module computed as ``embed_dim // num_heads`` - max_seq_len (int): Maximum expected sequence length for the - model, if exceeded the cached freqs will be recomputed - base (int): The base for the geometric progression used to compute - the rotation angles + Creates a linear ramp mask. The ramp is from 0 to 1. + Values of torch.arange(num_dims_to_mask) < min_val will be 0. + Values > max_val will be 1. """ + if min_val == max_val: + max_val += 0.001 # Avoid division by zero + + # Create points for the ramp from 0 to num_dims_to_mask-1 + dim_indices = torch.arange(num_dims_to_mask, dtype=torch.float32) + + # Calculate the ramp + # (current_dim_index - ramp_start_point) / (ramp_end_point - ramp_start_point) + linear_func = (dim_indices - min_val) / (max_val - min_val) + ramp_func = torch.clamp(linear_func, 0, 1) # Clamp values to be between 0 and 1 + return ramp_func +def yarn_get_mscale(scaling_factor: float = 1.0, mscale_hyperparam: float = 1.0) -> float: + """Calculates the magnitude scaling factor component for YaRN.""" + if scaling_factor <= 1.0: + return 1.0 + return 0.1 * mscale_hyperparam * math.log(scaling_factor) + 1.0 + +# --- RoPE Application Helpers --- +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + +def _apply_rotary_pos_emb(tensor: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + """ + Applies RoPE to a tensor. + tensor: Shape (..., seq_len, head_dim) + cos, sin: Shape (..., seq_len, head_dim), broadcastable with tensor. 
+ """ + return (tensor * cos) + (_rotate_half(tensor) * sin) + + +# --- YarnRotaryPositionalEmbeddings Class --- +class YarnRotaryPositionalEmbeddings(nn.Module): def __init__( self, - dim: int, - max_seq_len: int = 4096, - base: int = 10_000, - ) -> None: + head_dim: int, + max_position_embeddings: int = 4096, # New, extended max sequence length + base: float = 10000.0, + original_max_position_embeddings: int = 2048, + scaling_factor: float = 1.0, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1.0, + mscale_all_dim: float = 0.0, + dtype: torch.dtype = torch.float32, + ): super().__init__() - self.dim = dim + if head_dim % 2 != 0: + raise ValueError("head_dim must be divisible by 2 for RoPE.") + + self.head_dim = head_dim + self.max_position_embeddings = max_position_embeddings # Target extended length self.base = base - self.max_seq_len = max_seq_len - self.rope_init() + self.original_max_position_embeddings = original_max_position_embeddings + self.scaling_factor = scaling_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale_hyperparam = mscale + self.mscale_all_dim_hyperparam = mscale_all_dim + self.dtype = dtype + + self.inv_freq = self._calculate_yarn_inv_freq() + self.m_scale_factor = self._calculate_yarn_magnitude_scale() + + self.register_buffer("_inv_freq_buffer", self.inv_freq.to(self.dtype), persistent=False) + + self.cos_cached: Optional[torch.Tensor] = None + self.sin_cached: Optional[torch.Tensor] = None + self.max_seq_len_cached: int = 0 + + # Pre-cache up to the new max length if needed, or handle dynamically + self._build_cache(self.max_position_embeddings) - def rope_init(self): - theta = 1.0 / ( - self.base - ** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim) + def _calculate_yarn_inv_freq(self) -> torch.Tensor: + dim_half = self.head_dim // 2 + # Frequencies are calculated for dim_half elements + + freq_extra = 1.0 / ( + self.base ** (torch.arange(0, self.head_dim, 2, dtype=self.dtype) / self.head_dim) + ) + freq_inter = 1.0 / ( + self.scaling_factor * (self.base ** (torch.arange(0, self.head_dim, 2, dtype=self.dtype) / self.head_dim)) ) - self.register_buffer("theta", theta, persistent=False) - self.build_rope_cache(self.max_seq_len) - def build_rope_cache(self, max_seq_len: int = 4096) -> None: - # Create position indexes `[0, 1, ..., max_seq_len - 1]` - seq_idx = torch.arange( - max_seq_len, dtype=self.theta.dtype, device=self.theta.device + # low_bound_for_ramp and high_bound_for_ramp are indices in the full head_dim space (0 to head_dim-1) + # These define where the ramp starts and ends relative to the original dimensions. + low_bound_for_ramp, high_bound_for_ramp = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + self.head_dim, # Full head_dim + self.base, + self.original_max_position_embeddings, + ) + + # The ramp_values are for the dim_half frequencies. + # yarn_linear_ramp_mask(min_val, max_val, num_dims_to_mask) + # min_val and max_val here are interpreted as points along the 0..num_dims_to_mask-1 axis. + # If low_bound_for_ramp is an index in full_dim, for the mask of length dim_half, + # the corresponding start point for the ramp is low_bound_for_ramp / 2. + # This detail is critical for how the ramp aligns with the dimensions. + # DeepSeek's code: yarn_linear_ramp_mask(low, high, dim // 2) + # This implies 'low' and 'high' (from yarn_find_correction_range on full 'dim') + # are directly used as the min_val and max_val for a ramp over 'dim // 2' elements. 
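+        # Net effect: ramp_values rises from 0 to 1 across the dim_half
+        # frequencies, so the blend below keeps the original frequencies
+        # (freq_extra) for the fast-rotating low indices and uses the
+        # 1/scaling_factor-interpolated frequencies (freq_inter) for the
+        # slow, long-wavelength indices.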
+ ramp_values = yarn_linear_ramp_mask( + low_bound_for_ramp, # Use directly as per DeepSeek's pattern + high_bound_for_ramp, # Use directly + dim_half # The number of elements in the mask ) + + # Interpolation based on DeepSeek's YarnRotaryEmbedding: + # inv_freq = freq_inter * (1 - inv_freq_mask_deepseek) + freq_extra * inv_freq_mask_deepseek + # where inv_freq_mask_deepseek = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2) + # This means: inv_freq = freq_inter * ramp_values + freq_extra * (1.0 - ramp_values) + inv_freq_yarn = freq_inter * ramp_values + freq_extra * (1.0 - ramp_values) + return inv_freq_yarn + + def _calculate_yarn_magnitude_scale(self) -> float: + m_scale_numerator = yarn_get_mscale(self.scaling_factor, self.mscale_hyperparam) + + # If mscale_all_dim_hyperparam is 0.0, yarn_get_mscale will use 0.0 for its mscale_hyperparam, + # resulting in a factor of 1.0 if scaling_factor > 1.0. + m_scale_denominator = yarn_get_mscale(self.scaling_factor, self.mscale_all_dim_hyperparam) + + if abs(m_scale_denominator) < 1e-8: + m_scale_denominator = 1.0 + return m_scale_numerator / m_scale_denominator - # Outer product of theta and position index; output tensor has - # a shape of [max_seq_len, dim // 2] - idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float() + def _build_cache(self, seq_len: int): + if seq_len <= self.max_seq_len_cached and self.cos_cached is not None and self.sin_cached is not None: + return # Cache is already sufficient - # cache includes both the cos and sin components and so the output shape is - # [max_seq_len, dim // 2, 2] - cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) - self.register_buffer("cache", cache, persistent=False) + self.max_seq_len_cached = seq_len + + # Ensure inv_freq is on the correct device. It should be due to register_buffer. + current_device = self._inv_freq_buffer.device + inv_freq_to_use = self._inv_freq_buffer.to(current_device) - def forward( - self, x: torch.Tensor, *, input_pos: Optional[torch.Tensor] = None - ) -> torch.Tensor: + t = torch.arange(seq_len, device=current_device, dtype=self.dtype) + freqs = torch.outer(t, inv_freq_to_use) # Shape: (seq_len, head_dim // 2) + + # Create embeddings of shape (seq_len, head_dim) for cos and sin + # Each frequency in freqs corresponds to a pair of dimensions. + # Standard RoPE implementations often create cos/sin for head_dim//2 and then duplicate or interleave. + # DeepSeek's implementation (and HF's) often does: + # emb = torch.cat((freqs, freqs), dim=-1) # (seq_len, head_dim) + # cos_cached = emb.cos() + # This means the same frequency is applied to two consecutive dimensions, but one gets cos, other sin via _rotate_half. + # For direct application with _apply_rotary_pos_emb, we need cos and sin to be (seq_len, head_dim) + # where cos[..., 0:dim//2] and cos[..., dim//2:dim] are derived from freqs, and similarly for sin. + # More precisely, for a dimension `j`, if `j` is even, `cos_part = cos(pos * inv_freq[j//2])`, + # if `j` is odd, `cos_part = cos(pos * inv_freq[j//2])`. + # And for sin, if `j` is even, `sin_part = sin(pos * inv_freq[j//2])`, + # if `j` is odd, `sin_part = sin(pos * inv_freq[j//2])`. + # This is what `torch.cat((freqs, freqs), dim=-1)` effectively prepares for `_apply_rotary_pos_emb`. 
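+        # Shape sketch (assuming head_dim=64 and seq_len=1024 for illustration):
+        # freqs is (1024, 32); emb below is (1024, 64); the cached cos/sin keep
+        # that (seq_len, head_dim) shape and are scaled by m_scale_factor.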
+ + emb = torch.cat((freqs, freqs), dim=-1) # Shape: (seq_len, head_dim) + + self.cos_cached = (emb.cos() * self.m_scale_factor).to(self.dtype) + self.sin_cached = (emb.sin() * self.m_scale_factor).to(self.dtype) + + # Update buffers if they exist, otherwise create them + if hasattr(self, '_cos_cached_buffer'): + self.register_buffer("_cos_cached_buffer", self.cos_cached, persistent=False) + self.register_buffer("_sin_cached_buffer", self.sin_cached, persistent=False) + else: # First time + self.register_buffer("_cos_cached_buffer", self.cos_cached, persistent=False) + self.register_buffer("_sin_cached_buffer", self.sin_cached, persistent=False) + + + def forward(self, x: torch.Tensor, input_pos: Optional[torch.Tensor] = None) -> torch.Tensor: """ Args: - x (torch.Tensor): input tensor with shape - ``[b, s, n_h, h_d]`` - input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids - of each token. During training, this is used to indicate the positions - of each token relative to its sample when packed, shape [b, s]. - During inference, this indicates the position of the current token. - If none, assume the index of the token is its position id. Default is None. - + x (torch.Tensor): Input tensor, e.g., query (Q) or key (K). + Expected shape: (batch_size, num_heads, seq_len, head_dim). + input_pos (Optional[torch.Tensor]): Positions of tokens. + Shape: (batch_size, seq_len) or (seq_len,). + If None, assumes positions are [0, 1, ..., seq_len-1]. Returns: - torch.Tensor: output tensor with shape ``[b, s, n_h, h_d]`` - - Notation used for tensor shapes: - - b: batch size - - s: sequence length - - n_h: num heads - - h_d: head dim + torch.Tensor: Rotated tensor with the same shape as x. """ - # input tensor has shape [b, s, n_h, h_d] - seq_len = x.size(1) + batch_size, num_heads, seq_len, head_dim_x = x.shape + assert head_dim_x == self.head_dim, "Input head_dim does not match module's head_dim" - # extract the values based on whether input_pos is set or not - rope_cache = ( - self.cache[:seq_len] if input_pos is None else self.cache[input_pos] - ) + self._build_cache(max(seq_len, self.max_seq_len_cached)) # Ensure cache is up-to-date for current seq_len - # reshape input; the last dimension is used for computing the output. 
- # Cast to float to match the reference implementation - # tensor has shape [b, s, n_h, h_d // 2, 2] - xshaped = x.float().reshape(*x.shape[:-1], -1, 2) - - # reshape the cache for broadcasting - # tensor has shape [b, s, 1, h_d // 2, 2] if packed samples, - # otherwise has shape [1, s, 1, h_d // 2, 2] - rope_cache = rope_cache.view(-1, xshaped.size(1), 1, xshaped.size(3), 2) - - # tensor has shape [b, s, n_h, h_d // 2, 2] - x_out = torch.stack( - [ - xshaped[..., 0] * rope_cache[..., 0] - - xshaped[..., 1] * rope_cache[..., 1], - xshaped[..., 1] * rope_cache[..., 0] - + xshaped[..., 0] * rope_cache[..., 1], - ], - -1, - ) + if input_pos is None: + # Use positions [0, 1, ..., seq_len-1] + # Slice the cache: (seq_len, head_dim) + cos = self._cos_cached_buffer[:seq_len] + sin = self._sin_cached_buffer[:seq_len] + # Reshape for broadcasting: (1, 1, seq_len, head_dim) + cos = cos.view(1, 1, seq_len, self.head_dim) + sin = sin.view(1, 1, seq_len, self.head_dim) + else: + # input_pos shape: (batch_size, seq_len) or (seq_len for KV cache current token) + # Gather from cache: results in (batch_size, seq_len, head_dim) or (seq_len, head_dim) + cos = self._cos_cached_buffer[input_pos] + sin = self._sin_cached_buffer[input_pos] + # Reshape for broadcasting with x (bs, num_h, slen, hdim) + if cos.ndim == 2: # (slen, hdim) - typical for KV cache current token + cos = cos.view(1, 1, -1, self.head_dim) # (1, 1, slen_kv, hdim) + sin = sin.view(1, 1, -1, self.head_dim) # (1, 1, slen_kv, hdim) + elif cos.ndim == 3: # (bs, slen, hdim) + cos = cos.unsqueeze(1) # (bs, 1, slen, hdim) + sin = sin.unsqueeze(1) # (bs, 1, slen, hdim) + # If input_pos was (1, slen) for a single sample in batch with full sequence + elif cos.ndim == 2 and input_pos.ndim == 2 and input_pos.shape[0] == 1: # (1, slen, hdim) + cos = cos.unsqueeze(0).unsqueeze(1) # (1,1,slen,hdim) + sin = sin.unsqueeze(0).unsqueeze(1) + + + rotated_x = _apply_rotary_pos_emb(x, cos, sin) + return rotated_x + +if __name__ == '__main__': + # Example Usage + HEAD_DIM = 64 + MAX_EXTENDED_LEN = 1024 + ORIGINAL_MAX_LEN = 256 + SCALING_FACTOR = MAX_EXTENDED_LEN / ORIGINAL_MAX_LEN # s = 4.0 + + yarn_rope = YarnRotaryPositionalEmbeddings( + head_dim=HEAD_DIM, + max_position_embeddings=MAX_EXTENDED_LEN, + base=10000.0, + original_max_position_embeddings=ORIGINAL_MAX_LEN, + scaling_factor=SCALING_FACTOR, + beta_fast=32, + beta_slow=1, + mscale=1.0, + mscale_all_dim=0.0, # Common setting from DeepSeek config + dtype=torch.float32 + ) + + BATCH_SIZE = 2 + NUM_HEADS = 4 + SEQ_LEN_TEST = 512 + + # Dummy Q tensor: (bs, num_heads, seq_len, head_dim) + q_tensor = torch.randn(BATCH_SIZE, NUM_HEADS, SEQ_LEN_TEST, HEAD_DIM) + + # 1. Test with implicit positions [0, ..., SEQ_LEN_TEST-1] + q_rotated_implicit_pos = yarn_rope(q_tensor) + print(f"Shape of Q after YaRN RoPE (implicit positions): {q_rotated_implicit_pos.shape}") + + # 2. Test with explicit positions (e.g., for packed sequences or KV cache) + # Example: first sample uses pos 0-511, second sample uses pos 100-611 + pos_ids_sample1 = torch.arange(SEQ_LEN_TEST) + pos_ids_sample2 = torch.arange(100, 100 + SEQ_LEN_TEST) + explicit_pos = torch.stack([pos_ids_sample1, pos_ids_sample2], dim=0) # (bs, seq_len) + + q_rotated_explicit_pos = yarn_rope(q_tensor, input_pos=explicit_pos) + print(f"Shape of Q after YaRN RoPE (explicit positions): {q_rotated_explicit_pos.shape}") - # tensor has shape [b, s, n_h, h_d] - x_out = x_out.flatten(3) - return x_out.type_as(x) + # 3. 
Test KV cache scenario (single new token position) + # Assume current token is at position 512 (0-indexed) + current_token_pos = torch.tensor([SEQ_LEN_TEST], dtype=torch.long) # Shape (1,) or (bs, 1) + # For a single token, seq_len in Q/K would be 1 + k_tensor_current = torch.randn(BATCH_SIZE, NUM_HEADS, 1, HEAD_DIM) + k_rotated_kv_cache = yarn_rope(k_tensor_current, input_pos=current_token_pos.unsqueeze(0).expand(BATCH_SIZE, -1)) # (bs, 1) + print(f"Shape of K for current token after YaRN RoPE (KV cache): {k_rotated_kv_cache.shape}") + # Test if cache rebuilds for longer sequence + SEQ_LEN_LONGER = MAX_EXTENDED_LEN + 100 + q_tensor_longer = torch.randn(BATCH_SIZE, NUM_HEADS, SEQ_LEN_LONGER, HEAD_DIM) + print(f"Max cached before longer: {yarn_rope.max_seq_len_cached}") + q_rotated_longer = yarn_rope(q_tensor_longer) + print(f"Max cached after longer: {yarn_rope.max_seq_len_cached}") + print(f"Shape of Q after YaRN RoPE (longer sequence, cache rebuild): {q_rotated_longer.shape}") \ No newline at end of file diff --git a/torchtune/models/qwen2/_positional_embeddings.py b/torchtune/models/qwen2/_positional_embeddings.py index 61e8682783..b6265e6c9d 100644 --- a/torchtune/models/qwen2/_positional_embeddings.py +++ b/torchtune/models/qwen2/_positional_embeddings.py @@ -49,7 +49,7 @@ def rope_init(self): self.build_rope_cache(self.max_seq_len) def build_rope_cache(self, max_seq_len: int = 4096) -> None: - # Create position indexes `[0, 1, ..., max_seq_len - 1]` + # Create position indexes `[0, 1, ..., max_seq_len - 1] seq_idx = torch.arange( max_seq_len, dtype=self.theta.dtype, device=self.theta.device ) From d2ba574dfc73b388c9ea948f13a4c0b91297819b Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 22 May 2025 10:49:30 +0100 Subject: [PATCH 08/25] mem debugging --- torchtune/models/deepseek_v3/_component_builders.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchtune/models/deepseek_v3/_component_builders.py b/torchtune/models/deepseek_v3/_component_builders.py index 9666bf92fa..23bbe2811d 100644 --- a/torchtune/models/deepseek_v3/_component_builders.py +++ b/torchtune/models/deepseek_v3/_component_builders.py @@ -50,10 +50,12 @@ def deepseek_v3( use_yarn = False if use_yarn: pass - else: - rope = RotaryPositionalEmbeddings(dim=qk_rope_head_dim, max_seq_len=max_seq_len, base=rope_base) + # else: + # rope = RotaryPositionalEmbeddings(dim=qk_rope_head_dim, max_seq_len=max_seq_len, base=rope_base) + rope = nn.Identity() layers = [] for i in range(num_layers): + print("layer idx, mps memory usage", i, torch.mps.current_allocated_memory() / 1024**3, "GB") q_head_dim = qk_rope_head_dim + qk_nope_head_dim if q_lora_rank is None: q_proj = nn.Linear(embed_dim, num_heads * q_head_dim, bias=False) From 3205dea6f0bfb2f435c33fddb5167b784dfd19ef Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 22 May 2025 16:56:40 +0100 Subject: [PATCH 09/25] debugigng --- torchtune/models/deepseek_v3/_component_builders.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchtune/models/deepseek_v3/_component_builders.py b/torchtune/models/deepseek_v3/_component_builders.py index 23bbe2811d..bdc2bb4e53 100644 --- a/torchtune/models/deepseek_v3/_component_builders.py +++ b/torchtune/models/deepseek_v3/_component_builders.py @@ -52,10 +52,11 @@ def deepseek_v3( pass # else: # rope = RotaryPositionalEmbeddings(dim=qk_rope_head_dim, max_seq_len=max_seq_len, base=rope_base) - rope = nn.Identity() + # rope = nn.Identity() + def rope(x, input_pos=None): + 
return x layers = [] for i in range(num_layers): - print("layer idx, mps memory usage", i, torch.mps.current_allocated_memory() / 1024**3, "GB") q_head_dim = qk_rope_head_dim + qk_nope_head_dim if q_lora_rank is None: q_proj = nn.Linear(embed_dim, num_heads * q_head_dim, bias=False) From 8cb5642a5f8a993ed09c985e53ee38a6b854e1a7 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Tue, 27 May 2025 17:27:51 +0100 Subject: [PATCH 10/25] return to this commit if something breaks --- torchtune/models/deepseek_v3/_attention.py | 29 +++--- .../models/deepseek_v3/_component_builders.py | 17 ++-- .../models/deepseek_v3/_convert_weights.py | 60 +++++++------ torchtune/models/deepseek_v3/_experts.py | 90 +++++++++++++++++++ torchtune/models/deepseek_v3/_moe.py | 80 +++++++++++------ torchtune/modules/moe/experts.py | 2 +- torchtune/modules/moe/moe.py | 5 +- torchtune/modules/transformer.py | 8 +- 8 files changed, 214 insertions(+), 77 deletions(-) create mode 100644 torchtune/models/deepseek_v3/_experts.py diff --git a/torchtune/models/deepseek_v3/_attention.py b/torchtune/models/deepseek_v3/_attention.py index 69dc8732c2..41bae08fd6 100644 --- a/torchtune/models/deepseek_v3/_attention.py +++ b/torchtune/models/deepseek_v3/_attention.py @@ -48,7 +48,7 @@ def __init__(self, self.cache_enabled = False self._attention_call = _sdpa_or_flex_attention() - + def forward( self, x: torch.Tensor, @@ -57,8 +57,8 @@ def forward( mask: Optional[_MaskType] = None, input_pos: Optional[torch.Tensor] = None, ) -> torch.Tensor: - - # import ipdb; ipdb.set_trace() + + # # import ipdb; ipdb.set_trace() # q is sometimes decomposed into A/B # kv is *always* decomposed @@ -74,7 +74,7 @@ def forward( q = self.q_proj(x) q = q.view(b, s_x, self.num_heads, self.q_head_dim) q = q.transpose(1, 2) - + q_nope, q_pe = torch.split( q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) @@ -84,23 +84,21 @@ def forward( kv = kv.transpose(1, 2) k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - q_pe = q_pe.transpose(1, 2) - k_pe = k_pe.transpose(1, 2) + # q_pe = q_pe.transpose(1, 2) + # k_pe = k_pe.transpose(1, 2) q_pe = self.pos_embeddings(q_pe, input_pos=input_pos) k_pe = self.pos_embeddings(k_pe, input_pos=input_pos) - q_pe = q_pe.transpose(1, 2) - k_pe = k_pe.transpose(1, 2) + # q_pe = q_pe.transpose(1, 2) + # k_pe = k_pe.transpose(1, 2) query_states = k_pe.new_empty(b, self.num_heads, s_x, self.q_head_dim) query_states[:, :, :, : self.qk_nope_head_dim] = q_nope - query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + query_states[:, :, :, self.qk_nope_head_dim:] = q_pe key_states = k_pe.new_empty(b, self.num_heads, s_x, self.q_head_dim) key_states[:, :, :, : self.qk_nope_head_dim] = k_nope - key_states[:, :, :, self.qk_nope_head_dim :] = k_pe - + key_states[:, :, :, self.qk_nope_head_dim:] = k_pe output = self._attention_call( query_states, @@ -111,6 +109,13 @@ def forward( is_causal=mask is None, scale=self.softmax_scale, ) + # print(f"attn output\n") + # print(f"\tshape: {output.shape}") + # print(f"\tmean: {output.mean()}") + # print(f"\tstd: {output.std()}") + # print(f"\tmin: {output.min()}") + # print(f"\tmax: {output.max()}") + # exit() # reshape the output to be the same shape as the input output = output.transpose(1, 2).contiguous().view(b, s_x, -1) diff --git a/torchtune/models/deepseek_v3/_component_builders.py b/torchtune/models/deepseek_v3/_component_builders.py index bdc2bb4e53..0f954b1fad 100644 --- a/torchtune/models/deepseek_v3/_component_builders.py +++ 
b/torchtune/models/deepseek_v3/_component_builders.py @@ -8,6 +8,7 @@ import torch from torch import nn +from torchtune.models.deepseek_v3._experts import DeepseekV3GroupedExperts from torchtune.models.deepseek_v3._linear import DeepSeekV3LatentLinear from torchtune.models.deepseek_v3._attention import DeepSeekV3Attention from torchtune.models.deepseek_v3._moe import DeepSeekV3TokenChoiceTopKRouter, DeepseekV3MoE @@ -89,13 +90,8 @@ def rope(x, input_pos=None): is_moe = (moe_every_n_layers is None or (i + 1) % moe_every_n_layers == 0) and i >= first_moe_layer if is_moe: mlp_layer = DeepseekV3MoE( - experts=GroupedExperts( - dim=embed_dim, - hidden_dim=moe_hidden_dim, - num_experts=num_experts, - ), + experts=deepseek_v3_experts(num_experts, embed_dim, moe_hidden_dim), router=DeepSeekV3TokenChoiceTopKRouter( - gate=nn.Linear(embed_dim, num_experts, bias=False), dim=embed_dim, num_experts=num_experts, experts_per_token=experts_per_token, @@ -131,6 +127,15 @@ def rope(x, input_pos=None): output=output_proj, ) +def deepseek_v3_experts( + num_experts: int, + dim: int, + hidden_dim: int, +) -> nn.ModuleDict: + experts = nn.ModuleDict({ + str(i): deepseek_v3_mlp(dim, hidden_dim) for i in range(num_experts) + }) + return experts def deepseek_v3_mlp( dim: int, diff --git a/torchtune/models/deepseek_v3/_convert_weights.py b/torchtune/models/deepseek_v3/_convert_weights.py index 33886716e8..7cea272fc0 100644 --- a/torchtune/models/deepseek_v3/_convert_weights.py +++ b/torchtune/models/deepseek_v3/_convert_weights.py @@ -2,6 +2,7 @@ import torch from torchtune.models.convert_weights import get_mapped_key import regex as re +from typing import Dict # hf_model # DeepseekV3ForCausalLM( # (model): DeepseekV3Model( @@ -137,6 +138,7 @@ "model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale", "model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale", "model.norm.weight": "norm.scale", + # attenion weights "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight", "model.layers.{}.self_attn.q_a_proj.weight": "layers.{}.attn.q_proj.a.weight", @@ -152,13 +154,18 @@ "model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.w3.weight", "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight", + # mlp MoE expert weights + "model.layers.{}.mlp.experts.{}.gate_proj.weight": "layers.{}.mlp.experts.{}.w1.weight", + "model.layers.{}.mlp.experts.{}.up_proj.weight": "layers.{}.mlp.experts.{}.w3.weight", + "model.layers.{}.mlp.experts.{}.down_proj.weight": "layers.{}.mlp.experts.{}.w2.weight", + # mlp MoE shared expert weights "model.layers.{}.mlp.shared_experts.gate_proj.weight": "layers.{}.mlp.shared_expert.w1.weight", "model.layers.{}.mlp.shared_experts.up_proj.weight": "layers.{}.mlp.shared_expert.w3.weight", "model.layers.{}.mlp.shared_experts.down_proj.weight": "layers.{}.mlp.shared_expert.w2.weight", # mlp MoE token router weights - "model.layers.{}.mlp.gate.weight": "layers.{}.mlp.router.gate.weight", + "model.layers.{}.mlp.gate.weight": "layers.{}.mlp.router.gate", "model.layers.{}.mlp.gate.e_score_correction_bias": "layers.{}.mlp.router.e_score_correction_bias", "lm_head.weight": "output.weight", @@ -166,32 +173,35 @@ } -def deepseek_v3_hf_to_tune(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: - converted_state_dict = {} +def get_mapped_key(key: str, mapping_dict: Dict[str, str]) -> str: + try: + # Checks if there is a layer # in the key + if any(k.isdigit() for k in key.split(".")): + # Replace all numbers with "{}" to create key for lookup + 
abstract_key = re.sub(r"(\.\d+)", ".{}", key) + # Find all numbers in the key in order + layer_nums = re.findall(r"\d+", key) + new_key = mapping_dict[abstract_key] + # Format with all numbers + new_key = new_key.format(*layer_nums) + else: + new_key = mapping_dict[key] + except KeyError as e: + raise Exception( + f'Error converting the state dict. Found unexpected key: "{key}". ' + "Please make sure you're loading a checkpoint with the right format. " + ) from e - # first merge expert weights - expert_weights_grouped = defaultdict(lambda: defaultdict(list)) - expert_keys_processed = set() - for key, value in state_dict.items(): - expert_match = re.match( - r"model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)\.weight", key) - if expert_match: - layer_idx = expert_match.group(1) - expert_idx = int(expert_match.group(2)) - proj_name_part = expert_match.group(3) + return new_key - expert_weights_grouped[layer_idx][proj_name_part].append((expert_idx, value)) - expert_keys_processed.add(key) - - for layer_idx, projections in expert_weights_grouped.items(): - for proj_type, weights_list in projections.items(): - weights_list.sort(key=lambda x: x[0]) - stacked_weights = torch.stack([w[1].transpose(0, 1) for w in weights_list], dim=0) - new_key = f"layers.{layer_idx}.mlp.experts.{proj_type}" - converted_state_dict[new_key] = stacked_weights +def deepseek_v3_hf_to_tune(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + converted_state_dict = {} for key, value in state_dict.items(): - if key not in expert_keys_processed and "rotary_emb.inv_freq" not in key: - new_key = get_mapped_key(key, _FROM_HF) - converted_state_dict[new_key] = value + # Skip keys that should be ignored (like rotary embeddings) + if "rotary_emb.inv_freq" in key: + continue + + new_key = get_mapped_key(key, _FROM_HF) + converted_state_dict[new_key] = value return converted_state_dict diff --git a/torchtune/models/deepseek_v3/_experts.py b/torchtune/models/deepseek_v3/_experts.py new file mode 100644 index 0000000000..facba66fac --- /dev/null +++ b/torchtune/models/deepseek_v3/_experts.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable + +import torch +from torch import nn +from torch.nn import functional as F + + +class DeepseekV3GroupedExperts(nn.Module): + """This class implements the grouped experts layer used in Mixture of Experts. Each expert + is a variant of the Gated Linear Units network. See more details in https://arxiv.org/pdf/2002.05202. + + This class is identical to :class:`~torchtune.modules.moe.experts.GroupedExperts`, except that it uses a + `ModuleDict` to store the gate, down, and up projection matrices for each expert, rather than a + combined `nn.Parameter`. + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension. + num_experts (int): Number of experts in this grouped experts layer. Default is 1. + activation (Callable): Activation function to use. Default is F.silu. 
+ """ + + def __init__( + self, + *, + dim: int, + hidden_dim: int, + num_experts: int = 1, + activation: Callable = F.silu, + ): + super().__init__() + self.dim = dim + self.num_experts = num_experts + self.experts = nn.ModuleDict({ + f"expert_{i}": nn.Linear(dim, hidden_dim) for i in range(num_experts) + }) + self.experts_down = nn.ModuleDict({ + f"expert_{i}": nn.Linear(hidden_dim, dim) for i in range(num_experts) + }) + self.experts_up = nn.ModuleDict({ + f"expert_{i}": nn.Linear(dim, hidden_dim) for i in range(num_experts) + }) + self.act_fn = activation + + # TODO: force no inference mode as a hack to get around + # "Cannot set version_counter for inference tensor" + @torch.inference_mode(mode=False) + def forward( + self, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Tensor with shape ``(bsz * seq_len * experts_per_token, dim)`` + num_tokens_per_expert (torch.Tensor): Tensor with shape ``(num_experts,)`` + enumerating the number of tokens each expert receives + + Returns: + torch.Tensor: tensor with shape (bsz * seq_len * experts_per_token, dim) + """ + + # # import ipdb; ipdb.set_trace() + # a tuple of tensors indexed by experts + # each with shape (tokens_per_expert(varying), dim) + x = torch.split( + x, + split_size_or_sections=num_tokens_per_expert.tolist(), + dim=0, + ) + out_experts_splits = [] + for expert_idx, x_expert in enumerate(x): + w1, w2, w3 = ( + self.gate_proj[expert_idx], + self.down_proj[expert_idx], + self.up_proj[expert_idx], + ) + h = self.act_fn(torch.matmul(x_expert, w1)) + h = h * torch.matmul(x_expert, w3) + h = torch.matmul(h, w2) + # h shape (tokens_per_expert(varying), dim) + out_experts_splits.append(h) + out = torch.cat(out_experts_splits, dim=0) + + return out diff --git a/torchtune/models/deepseek_v3/_moe.py b/torchtune/models/deepseek_v3/_moe.py index 74326e65c3..8d68da8d08 100644 --- a/torchtune/models/deepseek_v3/_moe.py +++ b/torchtune/models/deepseek_v3/_moe.py @@ -4,10 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import torch.nn.functional as F import torch from torch import nn from typing import Optional + class DeepseekV3MoE(nn.Module): """This class implements the Mixture of Experts (MoE) layer for DeepSeek V3. This comprises a set of a router and a set of experts, which are typically smaller than MLP layers in standard @@ -26,7 +28,7 @@ class DeepseekV3MoE(nn.Module): def __init__( self, *, - experts: nn.Module, + experts: nn.ModuleDict, router: nn.Module, shared_expert: Optional[nn.Module] = None, ): @@ -43,6 +45,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. 
""" + # import ipdb + # ipdb.set_trace() b, s, dim = x.shape # top_scores and selected_indices shape (bs*slen*experts_per_token,) # num_tokens_per_expert shape (num_experts,) @@ -52,9 +56,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: num_tokens_per_expert, ) = self.router(x.reshape(b * s, dim)) - # shape (b*s*experts_per_token, dim) - token_indices = token_indices.reshape(-1, 1).expand(-1, dim) - + token_indices = token_indices.unsqueeze(1).expand(-1, dim) # shape (b*s*experts_per_token, dim) routed_input = torch.gather( x.view(-1, dim), @@ -63,25 +65,34 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) # shape (b*s*top_k, dim) - routed_output = self.experts(routed_input, num_tokens_per_expert) - - routed_output = routed_output * top_scores.reshape(-1, 1) - + routed_input = torch.split(routed_input, split_size_or_sections=num_tokens_per_expert.tolist(), dim=0) + routed_output = [] + for expert_idx, x_expert in enumerate(routed_input): + if x_expert.numel() == 0: + routed_output.append(torch.zeros_like(x_expert)) + continue + routed_output.append(self.experts[str(expert_idx)](x_expert)) # import ipdb; ipdb.set_trace() - # shared expert - if self.shared_expert is not None: - out = self.shared_expert(x).reshape(b * s, dim) - else: - out = torch.zeros_like(x.reshape(b * s, dim)) + routed_output = torch.cat(routed_output, dim=0) + import ipdb; ipdb.set_trace() + routed_output = routed_output * top_scores.unsqueeze(-1) + + out = torch.zeros_like(x.reshape(b * s, dim)).to(routed_output.dtype) if routed_output.numel() > 0: out.scatter_add_(dim=0, index=token_indices, src=routed_output) - out = out.reshape(b, s, dim) + + out = out.view(b, s, dim).to(x.dtype) + + if self.shared_expert is not None: + out += self.shared_expert(x) + + print_stats("output after shared", out) + exit() return out class DeepSeekV3TokenChoiceTopKRouter(nn.Module): def __init__(self, - gate: nn.Module, dim: int, num_experts: int, experts_per_token: int, @@ -91,7 +102,6 @@ def __init__(self, routed_scaling_factor: float ): super().__init__() - self.gate = gate self.dim = dim self.num_experts = num_experts self.experts_per_token = experts_per_token @@ -100,30 +110,34 @@ def __init__(self, self.norm_topk_prob = norm_topk_prob self.routed_scaling_factor = routed_scaling_factor self.e_score_correction_bias = nn.Parameter(torch.rand((self.num_experts))) + self.gate = nn.Parameter(torch.empty((num_experts, dim))) def forward(self, x: torch.Tensor) -> torch.Tensor: n = x.shape[0] - logits = self.gate(x) + logits = F.linear(x.to(torch.float32), self.gate.to(torch.float32), None) + # logits = self.gate(x) # calculate scores for every expert in every group - # import ipdb; ipdb.set_trace() - scores = torch.sigmoid(logits.to(torch.float32)).to(x.dtype) - scores += self.e_score_correction_bias.unsqueeze(0) + # # import ipdb; ipdb.set_trace() + scores = torch.sigmoid(logits) + print_stats("scores", scores) + scores_for_choice = scores + self.e_score_correction_bias.unsqueeze(0) + print_stats("scores_for_choice", scores_for_choice) # now calculate scores for every group based on the # top 2 scores of experts within each group experts_per_group = self.num_experts // self.num_groups group_scores = ( - scores.view(n, self.num_groups, experts_per_group) + scores_for_choice.view(n, self.num_groups, experts_per_group) .topk(2, dim=-1)[0].sum(dim=-1) ) + print_stats("group_scores", group_scores) + # grab the topk_groups number of groups based # on the scores for each group calculated above - group_idxs = torch.topk( - 
group_scores, k=self.topk_groups, dim=-1, sorted=False)[ - 1 - ] + group_idxs = torch.topk(group_scores, k=self.topk_groups, dim=-1, sorted=False).indices + print_stats("group_idxs", group_idxs) # mask out all experts within groups which will not be considered group_mask = torch.zeros_like(group_scores, dtype=torch.bool) group_mask.scatter_(1, group_idxs, True) # [n, n_group] @@ -135,11 +149,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) .reshape(n, -1) ) - - masked_scores = scores.masked_fill( + # masked_scores = scores + self.e_score_correction_bias.unsqueeze(0) + masked_scores = scores_for_choice.masked_fill( ~score_mask, float('-inf') ) - + print_stats("masked_scores", masked_scores) # now select the top experts_per_token number of # experts based on experts within eligible groups _, selected_experts_idxs = torch.topk(masked_scores, k=self.experts_per_token, dim=-1, sorted=False) @@ -151,7 +165,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: scores_per_token /= denominator # apply scaling factor - scores_per_token *= self.routed_scaling_factor + scores_per_token = scores_per_token * self.routed_scaling_factor num_tokens_per_expert = torch.histc( selected_experts_idxs.float(), bins=self.num_experts, min=0, max=self.num_experts - 1 @@ -162,7 +176,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) scores_per_expert = scores_per_token.view(-1)[token_idxs_experts_sorted] + print_stats("scores_per_expert", scores_per_expert) token_idxs_experts_sorted = ( token_idxs_experts_sorted // self.experts_per_token ) return scores_per_expert, token_idxs_experts_sorted, num_tokens_per_expert + + +def print_stats(name, x: torch.Tensor): + print(f"--{name}--") + print(f"max: {x.max()}, min: {x.min()}, mean: {x.float().mean()}, std: {x.float().std()}") + print(f"shape: {x.shape}") + # import ipdb; ipdb.set_trace() \ No newline at end of file diff --git a/torchtune/modules/moe/experts.py b/torchtune/modules/moe/experts.py index 02f89dbc6b..b14023857e 100644 --- a/torchtune/modules/moe/experts.py +++ b/torchtune/modules/moe/experts.py @@ -56,7 +56,7 @@ def forward( torch.Tensor: tensor with shape (bsz * seq_len * experts_per_token, dim) """ - # import ipdb; ipdb.set_trace() + # # import ipdb; ipdb.set_trace() # a tuple of tensors indexed by experts # each with shape (tokens_per_expert(varying), dim) x = torch.split( diff --git a/torchtune/modules/moe/moe.py b/torchtune/modules/moe/moe.py index f1423bf764..2167d67c85 100644 --- a/torchtune/modules/moe/moe.py +++ b/torchtune/modules/moe/moe.py @@ -85,7 +85,7 @@ class MoE(nn.Module): """This class implements the moe layer which is Mixture of Experts. Mixture of Experts typically consists of a set of expert networks, alongside with a router, which directs input tokens to the appropriate experts. See more details in https://arxiv.org/pdf/2407.06204. - + Args: experts (nn.Module): experts module. 
@@ -137,13 +137,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: routed_output = self.experts(routed_input, num_tokens_per_expert) routed_output = routed_output * top_scores.reshape(-1, 1) - # import ipdb; ipdb.set_trace() # shared expert if self.shared_expert is not None: out = self.shared_expert(x).reshape(b * s, dim) else: out = torch.zeros_like(x.reshape(b * s, dim)) - if routed_output.numel() > 0: + if routed_output.numel() > 0: out.scatter_add_(dim=0, index=token_indices, src=routed_output) out = out.reshape(b, s, dim) return out diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index 80e3cea782..9ef16b3d85 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -133,12 +133,18 @@ def forward( # Norm applied before the feedforward layer mlp_out = self.mlp(self.mlp_norm(h)) - + # import ipdb; ipdb.set_trace() # Residual connection; shape: [batch_size, seq_length, embed_dim] out = h + self.mlp_scale(mlp_out) return out +def print_stats(name, x: torch.Tensor): + print(f"--{name}--") + print(f"max: {x.max()}, min: {x.min()}, mean: {x.float().mean()}, std: {x.float().std()}") + print(f"shape: {x.shape}") + # import ipdb; ipdb.set_trace() + class TransformerCrossAttentionLayer(nn.Module): """ Cross attention Transformer layer following the same conventions as the TransformerSelfAttentionLayer. From fd126c1d08d84f3a4d991954911168910ed8208d Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Tue, 27 May 2025 17:32:56 +0100 Subject: [PATCH 11/25] reverting changes --- torchtune/models/deepseek_v3/_attention.py | 26 +- .../models/deepseek_v3/_component_builders.py | 7 +- torchtune/models/deepseek_v3/_experts.py | 2 - .../models/deepseek_v3/_model_builders.py | 5 - torchtune/models/deepseek_v3/_moe.py | 35 +- .../deepseek_v3/_position_embeddings.py | 354 ------------------ torchtune/modules/__init__.py | 2 + torchtune/modules/classifier.py | 12 +- torchtune/modules/kv_cache.py | 2 +- torchtune/modules/transformer.py | 25 +- torchtune/modules/vision_transformer.py | 3 + 11 files changed, 32 insertions(+), 441 deletions(-) diff --git a/torchtune/models/deepseek_v3/_attention.py b/torchtune/models/deepseek_v3/_attention.py index 41bae08fd6..8760c69db8 100644 --- a/torchtune/models/deepseek_v3/_attention.py +++ b/torchtune/models/deepseek_v3/_attention.py @@ -39,7 +39,6 @@ def __init__(self, self.is_causal = is_causal # Set layers - # self.kv_cache = kv_cache self.q_proj = q_proj self.kv_proj = kv_proj self.output_proj = output_proj @@ -57,19 +56,6 @@ def forward( mask: Optional[_MaskType] = None, input_pos: Optional[torch.Tensor] = None, ) -> torch.Tensor: - - # # import ipdb; ipdb.set_trace() - - # q is sometimes decomposed into A/B - # kv is *always* decomposed - - # when q is decomposed the norm is applied but - # not otherwise - in this case the norm - # should be applied after q a proj and before q b proj - - # for kv decomposition pos embeddings need to be extracted before - # projecting back up - b, s_x, _ = x.shape q = self.q_proj(x) q = q.view(b, s_x, self.num_heads, self.q_head_dim) @@ -85,13 +71,9 @@ def forward( k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - # q_pe = q_pe.transpose(1, 2) - # k_pe = k_pe.transpose(1, 2) q_pe = self.pos_embeddings(q_pe, input_pos=input_pos) k_pe = self.pos_embeddings(k_pe, input_pos=input_pos) - # q_pe = q_pe.transpose(1, 2) - # k_pe = k_pe.transpose(1, 2) query_states = k_pe.new_empty(b, self.num_heads, s_x, self.q_head_dim) 
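        # The full per-head query is assembled in the two assignments below: the first
        # qk_nope_head_dim channels carry the content ("no position") component, and the
        # remaining qk_rope_head_dim channels carry the RoPE-rotated component, so that
        # q_head_dim = qk_nope_head_dim + qk_rope_head_dim.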
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope query_states[:, :, :, self.qk_nope_head_dim:] = q_pe @@ -109,13 +91,7 @@ def forward( is_causal=mask is None, scale=self.softmax_scale, ) - # print(f"attn output\n") - # print(f"\tshape: {output.shape}") - # print(f"\tmean: {output.mean()}") - # print(f"\tstd: {output.std()}") - # print(f"\tmin: {output.min()}") - # print(f"\tmax: {output.max()}") - # exit() + # reshape the output to be the same shape as the input output = output.transpose(1, 2).contiguous().view(b, s_x, -1) diff --git a/torchtune/models/deepseek_v3/_component_builders.py b/torchtune/models/deepseek_v3/_component_builders.py index 0f954b1fad..127d52ac59 100644 --- a/torchtune/models/deepseek_v3/_component_builders.py +++ b/torchtune/models/deepseek_v3/_component_builders.py @@ -48,12 +48,9 @@ def deepseek_v3( moe_hidden_dim: Optional[int] = None, norm_eps: float = 1e-5, ): - use_yarn = False if use_yarn: - pass - # else: - # rope = RotaryPositionalEmbeddings(dim=qk_rope_head_dim, max_seq_len=max_seq_len, base=rope_base) - # rope = nn.Identity() + raise NotImplementedError("Yarn is not supported yet") + rope = RotaryPositionalEmbeddings(dim=qk_rope_head_dim, max_seq_len=max_seq_len, base=rope_base) def rope(x, input_pos=None): return x layers = [] diff --git a/torchtune/models/deepseek_v3/_experts.py b/torchtune/models/deepseek_v3/_experts.py index facba66fac..b27eda1f0e 100644 --- a/torchtune/models/deepseek_v3/_experts.py +++ b/torchtune/models/deepseek_v3/_experts.py @@ -64,8 +64,6 @@ def forward( Returns: torch.Tensor: tensor with shape (bsz * seq_len * experts_per_token, dim) """ - - # # import ipdb; ipdb.set_trace() # a tuple of tensors indexed by experts # each with shape (tokens_per_expert(varying), dim) x = torch.split( diff --git a/torchtune/models/deepseek_v3/_model_builders.py b/torchtune/models/deepseek_v3/_model_builders.py index 04ddc179ad..e69de29bb2 100644 --- a/torchtune/models/deepseek_v3/_model_builders.py +++ b/torchtune/models/deepseek_v3/_model_builders.py @@ -1,5 +0,0 @@ - -# def deepseek_v3_671b_256e( - -# ) -> TransformerDecoder: -# pass diff --git a/torchtune/models/deepseek_v3/_moe.py b/torchtune/models/deepseek_v3/_moe.py index 8d68da8d08..6ab432c1c9 100644 --- a/torchtune/models/deepseek_v3/_moe.py +++ b/torchtune/models/deepseek_v3/_moe.py @@ -16,9 +16,6 @@ class DeepseekV3MoE(nn.Module): transformer models. The router is used to select a subset of experts for each token, and the selected experts are then used to compute the output of the MoE layer. See more details in https://arxiv.org/2401.0606. - This class is identical to :class:`~torchtune.modules.moe.moe.MoE`, except that it applies the - router weighting scores to the *output* of the experts, rather than the input. - Args: experts (nn.Module): experts module. router (nn.Module): router module. @@ -45,8 +42,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. 
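        Example (illustrative sketch; assumes ``experts``, ``router`` and ``shared_expert``
        are pre-built modules with a matching ``dim``):
            >>> moe = DeepseekV3MoE(experts=experts, router=router, shared_expert=shared_expert)
            >>> x = torch.randn(2, 16, dim)   # (bs, slen, dim)
            >>> out = moe(x)                  # (bs, slen, dim)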
""" - # import ipdb - # ipdb.set_trace() + b, s, dim = x.shape # top_scores and selected_indices shape (bs*slen*experts_per_token,) # num_tokens_per_expert shape (num_experts,) @@ -58,6 +54,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: token_indices = token_indices.unsqueeze(1).expand(-1, dim) # shape (b*s*experts_per_token, dim) + routed_input = torch.gather( x.view(-1, dim), dim=0, @@ -72,22 +69,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: routed_output.append(torch.zeros_like(x_expert)) continue routed_output.append(self.experts[str(expert_idx)](x_expert)) - # import ipdb; ipdb.set_trace() + routed_output = torch.cat(routed_output, dim=0) - import ipdb; ipdb.set_trace() routed_output = routed_output * top_scores.unsqueeze(-1) out = torch.zeros_like(x.reshape(b * s, dim)).to(routed_output.dtype) if routed_output.numel() > 0: out.scatter_add_(dim=0, index=token_indices, src=routed_output) - + out = out.view(b, s, dim).to(x.dtype) - + if self.shared_expert is not None: - out += self.shared_expert(x) + out = out + self.shared_expert(x) - print_stats("output after shared", out) - exit() return out @@ -115,14 +109,11 @@ def __init__(self, def forward(self, x: torch.Tensor) -> torch.Tensor: n = x.shape[0] logits = F.linear(x.to(torch.float32), self.gate.to(torch.float32), None) - # logits = self.gate(x) + # calculate scores for every expert in every group - # # import ipdb; ipdb.set_trace() scores = torch.sigmoid(logits) - print_stats("scores", scores) scores_for_choice = scores + self.e_score_correction_bias.unsqueeze(0) - print_stats("scores_for_choice", scores_for_choice) # now calculate scores for every group based on the # top 2 scores of experts within each group experts_per_group = self.num_experts // self.num_groups @@ -131,13 +122,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: .topk(2, dim=-1)[0].sum(dim=-1) ) - print_stats("group_scores", group_scores) - # grab the topk_groups number of groups based # on the scores for each group calculated above group_idxs = torch.topk(group_scores, k=self.topk_groups, dim=-1, sorted=False).indices - print_stats("group_idxs", group_idxs) # mask out all experts within groups which will not be considered group_mask = torch.zeros_like(group_scores, dtype=torch.bool) group_mask.scatter_(1, group_idxs, True) # [n, n_group] @@ -153,7 +141,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: masked_scores = scores_for_choice.masked_fill( ~score_mask, float('-inf') ) - print_stats("masked_scores", masked_scores) # now select the top experts_per_token number of # experts based on experts within eligible groups _, selected_experts_idxs = torch.topk(masked_scores, k=self.experts_per_token, dim=-1, sorted=False) @@ -176,15 +163,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) scores_per_expert = scores_per_token.view(-1)[token_idxs_experts_sorted] - print_stats("scores_per_expert", scores_per_expert) token_idxs_experts_sorted = ( token_idxs_experts_sorted // self.experts_per_token ) return scores_per_expert, token_idxs_experts_sorted, num_tokens_per_expert - - -def print_stats(name, x: torch.Tensor): - print(f"--{name}--") - print(f"max: {x.max()}, min: {x.min()}, mean: {x.float().mean()}, std: {x.float().std()}") - print(f"shape: {x.shape}") - # import ipdb; ipdb.set_trace() \ No newline at end of file diff --git a/torchtune/models/deepseek_v3/_position_embeddings.py b/torchtune/models/deepseek_v3/_position_embeddings.py index 2231029957..e69de29bb2 100644 --- 
a/torchtune/models/deepseek_v3/_position_embeddings.py +++ b/torchtune/models/deepseek_v3/_position_embeddings.py @@ -1,354 +0,0 @@ -import torch -import torch.nn as nn -import math -from typing import Optional - -# --- Helper Functions for YaRN --- - -def yarn_find_correction_dim(num_rotations: int, - dim: int, # Full head dimension - base: float = 10000.0, - original_max_position_embeddings: int = 2048) -> float: - """ - Calculates the dimension index (in the full dim space) at which a certain - number of full rotations occur. - """ - return (dim * math.log(original_max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) - -def yarn_find_correction_range(beta_fast: int, # Min number of rotations for high freqs - beta_slow: int, # Max number of rotations for low freqs - dim: int, # Full head dimension - base: float = 10000.0, - original_max_position_embeddings: int = 2048) -> tuple[int, int]: - """ - Finds the range of dimension indices [low_idx, high_idx] (in the dim//2 frequency space) - that correspond to the specified rotation counts. These define the ramp for YaRN's interpolation. - """ - # These are indices in the full dimension space (0 to dim-1) - low_idx_full_dim = math.floor(yarn_find_correction_dim(beta_fast, dim, base, original_max_position_embeddings)) - high_idx_full_dim = math.ceil(yarn_find_correction_dim(beta_slow, dim, base, original_max_position_embeddings)) - - # The ramp mask is applied to dim // 2 frequencies. - # Each frequency element corresponds to two dimensions. - # So, we need to map these full_dim indices to the frequency_dim (dim//2) space. - # An index 'd' in full_dim corresponds to 'd // 2' in frequency_dim. - # However, DeepSeek's code uses these bounds directly with a mask of length dim//2. - # This implies that 'low_idx_full_dim' and 'high_idx_full_dim' are treated as bounds - # for the elements of the frequency vector (which has length dim//2). - # Let's stick to that interpretation for consistency with the reference. - - # Clamp values to be within valid indices for an array of length dim // 2 - # (i.e., 0 to dim//2 - 1) - dim_half = dim // 2 - low_idx_for_mask = max(low_idx_full_dim, 0) # Should be max(low_idx_full_dim // 2, 0) if strictly mapping - high_idx_for_mask = min(high_idx_full_dim, dim_half -1) # Should be min(high_idx_full_dim // 2, dim_half -1) - - # DeepSeek's `yarn_find_correction_range` returns `max(low,0), min(high, dim-1)` - # and then `yarn_linear_ramp_mask` takes `dim//2` as its length. - # The `low` and `high` are used directly as bounds for the mask of length `dim//2`. - # This means `low` and `high` are effectively indices into the `dim//2` array. - # So, the clamping should be against `dim_half - 1`. - - # Re-evaluating based on deepseek_tt.py: - # yarn_find_correction_range(self.beta_fast, self.beta_slow, dim, ...) - # -> low, high - # yarn_linear_ramp_mask(low, high, dim // 2) - # This implies `low` and `high` from `yarn_find_correction_range` are directly - # used as bounds for the mask of length `dim // 2`. - # The `dim` passed to `yarn_find_correction_range` is the full head_dim. - # The `dim` passed to `yarn_linear_ramp_mask` is `head_dim // 2`. - # The `low` and `high` values from `yarn_find_correction_range` are indices - # that can range up to `head_dim - 1`. - # When used in `yarn_linear_ramp_mask(low, high, head_dim // 2)`, these `low` and `high` - # are used as the `min_val` and `max_val` for a ramp over `head_dim // 2` elements. 
- # This seems to imply a scaling or interpretation of `low` and `high` within the ramp function. - # Let's assume the `yarn_linear_ramp_mask` expects `min_val` and `max_val` to be - # meaningful indices *within the range of `num_dims_to_mask`*. - # The `low` and `high` from `yarn_find_correction_range` in deepseek_tt are indeed - # clamped against `dim-1` (full dim). - # The most direct interpretation from deepseek_tt is that `low` and `high` are used as is. - - return max(low_idx_full_dim, 0), min(high_idx_full_dim, dim -1) # Return bounds in full_dim space - - -def yarn_linear_ramp_mask(min_val: float, # Start boundary for the ramp (can be outside 0 to num_dims_to_mask-1) - max_val: float, # End boundary for the ramp - num_dims_to_mask: int # Length of the mask, e.g., head_dim // 2 - ) -> torch.Tensor: - """ - Creates a linear ramp mask. The ramp is from 0 to 1. - Values of torch.arange(num_dims_to_mask) < min_val will be 0. - Values > max_val will be 1. - """ - if min_val == max_val: - max_val += 0.001 # Avoid division by zero - - # Create points for the ramp from 0 to num_dims_to_mask-1 - dim_indices = torch.arange(num_dims_to_mask, dtype=torch.float32) - - # Calculate the ramp - # (current_dim_index - ramp_start_point) / (ramp_end_point - ramp_start_point) - linear_func = (dim_indices - min_val) / (max_val - min_val) - ramp_func = torch.clamp(linear_func, 0, 1) # Clamp values to be between 0 and 1 - return ramp_func - -def yarn_get_mscale(scaling_factor: float = 1.0, mscale_hyperparam: float = 1.0) -> float: - """Calculates the magnitude scaling factor component for YaRN.""" - if scaling_factor <= 1.0: - return 1.0 - return 0.1 * mscale_hyperparam * math.log(scaling_factor) + 1.0 - -# --- RoPE Application Helpers --- -def _rotate_half(x: torch.Tensor) -> torch.Tensor: - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - -def _apply_rotary_pos_emb(tensor: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: - """ - Applies RoPE to a tensor. - tensor: Shape (..., seq_len, head_dim) - cos, sin: Shape (..., seq_len, head_dim), broadcastable with tensor. 
- """ - return (tensor * cos) + (_rotate_half(tensor) * sin) - - -# --- YarnRotaryPositionalEmbeddings Class --- -class YarnRotaryPositionalEmbeddings(nn.Module): - def __init__( - self, - head_dim: int, - max_position_embeddings: int = 4096, # New, extended max sequence length - base: float = 10000.0, - original_max_position_embeddings: int = 2048, - scaling_factor: float = 1.0, - beta_fast: int = 32, - beta_slow: int = 1, - mscale: float = 1.0, - mscale_all_dim: float = 0.0, - dtype: torch.dtype = torch.float32, - ): - super().__init__() - if head_dim % 2 != 0: - raise ValueError("head_dim must be divisible by 2 for RoPE.") - - self.head_dim = head_dim - self.max_position_embeddings = max_position_embeddings # Target extended length - self.base = base - self.original_max_position_embeddings = original_max_position_embeddings - self.scaling_factor = scaling_factor - self.beta_fast = beta_fast - self.beta_slow = beta_slow - self.mscale_hyperparam = mscale - self.mscale_all_dim_hyperparam = mscale_all_dim - self.dtype = dtype - - self.inv_freq = self._calculate_yarn_inv_freq() - self.m_scale_factor = self._calculate_yarn_magnitude_scale() - - self.register_buffer("_inv_freq_buffer", self.inv_freq.to(self.dtype), persistent=False) - - self.cos_cached: Optional[torch.Tensor] = None - self.sin_cached: Optional[torch.Tensor] = None - self.max_seq_len_cached: int = 0 - - # Pre-cache up to the new max length if needed, or handle dynamically - self._build_cache(self.max_position_embeddings) - - def _calculate_yarn_inv_freq(self) -> torch.Tensor: - dim_half = self.head_dim // 2 - # Frequencies are calculated for dim_half elements - - freq_extra = 1.0 / ( - self.base ** (torch.arange(0, self.head_dim, 2, dtype=self.dtype) / self.head_dim) - ) - freq_inter = 1.0 / ( - self.scaling_factor * (self.base ** (torch.arange(0, self.head_dim, 2, dtype=self.dtype) / self.head_dim)) - ) - - # low_bound_for_ramp and high_bound_for_ramp are indices in the full head_dim space (0 to head_dim-1) - # These define where the ramp starts and ends relative to the original dimensions. - low_bound_for_ramp, high_bound_for_ramp = yarn_find_correction_range( - self.beta_fast, - self.beta_slow, - self.head_dim, # Full head_dim - self.base, - self.original_max_position_embeddings, - ) - - # The ramp_values are for the dim_half frequencies. - # yarn_linear_ramp_mask(min_val, max_val, num_dims_to_mask) - # min_val and max_val here are interpreted as points along the 0..num_dims_to_mask-1 axis. - # If low_bound_for_ramp is an index in full_dim, for the mask of length dim_half, - # the corresponding start point for the ramp is low_bound_for_ramp / 2. - # This detail is critical for how the ramp aligns with the dimensions. - # DeepSeek's code: yarn_linear_ramp_mask(low, high, dim // 2) - # This implies 'low' and 'high' (from yarn_find_correction_range on full 'dim') - # are directly used as the min_val and max_val for a ramp over 'dim // 2' elements. 
- ramp_values = yarn_linear_ramp_mask( - low_bound_for_ramp, # Use directly as per DeepSeek's pattern - high_bound_for_ramp, # Use directly - dim_half # The number of elements in the mask - ) - - # Interpolation based on DeepSeek's YarnRotaryEmbedding: - # inv_freq = freq_inter * (1 - inv_freq_mask_deepseek) + freq_extra * inv_freq_mask_deepseek - # where inv_freq_mask_deepseek = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2) - # This means: inv_freq = freq_inter * ramp_values + freq_extra * (1.0 - ramp_values) - inv_freq_yarn = freq_inter * ramp_values + freq_extra * (1.0 - ramp_values) - return inv_freq_yarn - - def _calculate_yarn_magnitude_scale(self) -> float: - m_scale_numerator = yarn_get_mscale(self.scaling_factor, self.mscale_hyperparam) - - # If mscale_all_dim_hyperparam is 0.0, yarn_get_mscale will use 0.0 for its mscale_hyperparam, - # resulting in a factor of 1.0 if scaling_factor > 1.0. - m_scale_denominator = yarn_get_mscale(self.scaling_factor, self.mscale_all_dim_hyperparam) - - if abs(m_scale_denominator) < 1e-8: - m_scale_denominator = 1.0 - return m_scale_numerator / m_scale_denominator - - def _build_cache(self, seq_len: int): - if seq_len <= self.max_seq_len_cached and self.cos_cached is not None and self.sin_cached is not None: - return # Cache is already sufficient - - self.max_seq_len_cached = seq_len - - # Ensure inv_freq is on the correct device. It should be due to register_buffer. - current_device = self._inv_freq_buffer.device - inv_freq_to_use = self._inv_freq_buffer.to(current_device) - - t = torch.arange(seq_len, device=current_device, dtype=self.dtype) - freqs = torch.outer(t, inv_freq_to_use) # Shape: (seq_len, head_dim // 2) - - # Create embeddings of shape (seq_len, head_dim) for cos and sin - # Each frequency in freqs corresponds to a pair of dimensions. - # Standard RoPE implementations often create cos/sin for head_dim//2 and then duplicate or interleave. - # DeepSeek's implementation (and HF's) often does: - # emb = torch.cat((freqs, freqs), dim=-1) # (seq_len, head_dim) - # cos_cached = emb.cos() - # This means the same frequency is applied to two consecutive dimensions, but one gets cos, other sin via _rotate_half. - # For direct application with _apply_rotary_pos_emb, we need cos and sin to be (seq_len, head_dim) - # where cos[..., 0:dim//2] and cos[..., dim//2:dim] are derived from freqs, and similarly for sin. - # More precisely, for a dimension `j`, if `j` is even, `cos_part = cos(pos * inv_freq[j//2])`, - # if `j` is odd, `cos_part = cos(pos * inv_freq[j//2])`. - # And for sin, if `j` is even, `sin_part = sin(pos * inv_freq[j//2])`, - # if `j` is odd, `sin_part = sin(pos * inv_freq[j//2])`. - # This is what `torch.cat((freqs, freqs), dim=-1)` effectively prepares for `_apply_rotary_pos_emb`. 
- - emb = torch.cat((freqs, freqs), dim=-1) # Shape: (seq_len, head_dim) - - self.cos_cached = (emb.cos() * self.m_scale_factor).to(self.dtype) - self.sin_cached = (emb.sin() * self.m_scale_factor).to(self.dtype) - - # Update buffers if they exist, otherwise create them - if hasattr(self, '_cos_cached_buffer'): - self.register_buffer("_cos_cached_buffer", self.cos_cached, persistent=False) - self.register_buffer("_sin_cached_buffer", self.sin_cached, persistent=False) - else: # First time - self.register_buffer("_cos_cached_buffer", self.cos_cached, persistent=False) - self.register_buffer("_sin_cached_buffer", self.sin_cached, persistent=False) - - - def forward(self, x: torch.Tensor, input_pos: Optional[torch.Tensor] = None) -> torch.Tensor: - """ - Args: - x (torch.Tensor): Input tensor, e.g., query (Q) or key (K). - Expected shape: (batch_size, num_heads, seq_len, head_dim). - input_pos (Optional[torch.Tensor]): Positions of tokens. - Shape: (batch_size, seq_len) or (seq_len,). - If None, assumes positions are [0, 1, ..., seq_len-1]. - Returns: - torch.Tensor: Rotated tensor with the same shape as x. - """ - batch_size, num_heads, seq_len, head_dim_x = x.shape - assert head_dim_x == self.head_dim, "Input head_dim does not match module's head_dim" - - self._build_cache(max(seq_len, self.max_seq_len_cached)) # Ensure cache is up-to-date for current seq_len - - if input_pos is None: - # Use positions [0, 1, ..., seq_len-1] - # Slice the cache: (seq_len, head_dim) - cos = self._cos_cached_buffer[:seq_len] - sin = self._sin_cached_buffer[:seq_len] - # Reshape for broadcasting: (1, 1, seq_len, head_dim) - cos = cos.view(1, 1, seq_len, self.head_dim) - sin = sin.view(1, 1, seq_len, self.head_dim) - else: - # input_pos shape: (batch_size, seq_len) or (seq_len for KV cache current token) - # Gather from cache: results in (batch_size, seq_len, head_dim) or (seq_len, head_dim) - cos = self._cos_cached_buffer[input_pos] - sin = self._sin_cached_buffer[input_pos] - # Reshape for broadcasting with x (bs, num_h, slen, hdim) - if cos.ndim == 2: # (slen, hdim) - typical for KV cache current token - cos = cos.view(1, 1, -1, self.head_dim) # (1, 1, slen_kv, hdim) - sin = sin.view(1, 1, -1, self.head_dim) # (1, 1, slen_kv, hdim) - elif cos.ndim == 3: # (bs, slen, hdim) - cos = cos.unsqueeze(1) # (bs, 1, slen, hdim) - sin = sin.unsqueeze(1) # (bs, 1, slen, hdim) - # If input_pos was (1, slen) for a single sample in batch with full sequence - elif cos.ndim == 2 and input_pos.ndim == 2 and input_pos.shape[0] == 1: # (1, slen, hdim) - cos = cos.unsqueeze(0).unsqueeze(1) # (1,1,slen,hdim) - sin = sin.unsqueeze(0).unsqueeze(1) - - - rotated_x = _apply_rotary_pos_emb(x, cos, sin) - return rotated_x - -if __name__ == '__main__': - # Example Usage - HEAD_DIM = 64 - MAX_EXTENDED_LEN = 1024 - ORIGINAL_MAX_LEN = 256 - SCALING_FACTOR = MAX_EXTENDED_LEN / ORIGINAL_MAX_LEN # s = 4.0 - - yarn_rope = YarnRotaryPositionalEmbeddings( - head_dim=HEAD_DIM, - max_position_embeddings=MAX_EXTENDED_LEN, - base=10000.0, - original_max_position_embeddings=ORIGINAL_MAX_LEN, - scaling_factor=SCALING_FACTOR, - beta_fast=32, - beta_slow=1, - mscale=1.0, - mscale_all_dim=0.0, # Common setting from DeepSeek config - dtype=torch.float32 - ) - - BATCH_SIZE = 2 - NUM_HEADS = 4 - SEQ_LEN_TEST = 512 - - # Dummy Q tensor: (bs, num_heads, seq_len, head_dim) - q_tensor = torch.randn(BATCH_SIZE, NUM_HEADS, SEQ_LEN_TEST, HEAD_DIM) - - # 1. 
Test with implicit positions [0, ..., SEQ_LEN_TEST-1] - q_rotated_implicit_pos = yarn_rope(q_tensor) - print(f"Shape of Q after YaRN RoPE (implicit positions): {q_rotated_implicit_pos.shape}") - - # 2. Test with explicit positions (e.g., for packed sequences or KV cache) - # Example: first sample uses pos 0-511, second sample uses pos 100-611 - pos_ids_sample1 = torch.arange(SEQ_LEN_TEST) - pos_ids_sample2 = torch.arange(100, 100 + SEQ_LEN_TEST) - explicit_pos = torch.stack([pos_ids_sample1, pos_ids_sample2], dim=0) # (bs, seq_len) - - q_rotated_explicit_pos = yarn_rope(q_tensor, input_pos=explicit_pos) - print(f"Shape of Q after YaRN RoPE (explicit positions): {q_rotated_explicit_pos.shape}") - - # 3. Test KV cache scenario (single new token position) - # Assume current token is at position 512 (0-indexed) - current_token_pos = torch.tensor([SEQ_LEN_TEST], dtype=torch.long) # Shape (1,) or (bs, 1) - # For a single token, seq_len in Q/K would be 1 - k_tensor_current = torch.randn(BATCH_SIZE, NUM_HEADS, 1, HEAD_DIM) - k_rotated_kv_cache = yarn_rope(k_tensor_current, input_pos=current_token_pos.unsqueeze(0).expand(BATCH_SIZE, -1)) # (bs, 1) - print(f"Shape of K for current token after YaRN RoPE (KV cache): {k_rotated_kv_cache.shape}") - - # Test if cache rebuilds for longer sequence - SEQ_LEN_LONGER = MAX_EXTENDED_LEN + 100 - q_tensor_longer = torch.randn(BATCH_SIZE, NUM_HEADS, SEQ_LEN_LONGER, HEAD_DIM) - print(f"Max cached before longer: {yarn_rope.max_seq_len_cached}") - q_rotated_longer = yarn_rope(q_tensor_longer) - print(f"Max cached after longer: {yarn_rope.max_seq_len_cached}") - print(f"Shape of Q after YaRN RoPE (longer sequence, cache rebuild): {q_rotated_longer.shape}") \ No newline at end of file diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py index 29d857d032..2e25d424a1 100644 --- a/torchtune/modules/__init__.py +++ b/torchtune/modules/__init__.py @@ -34,6 +34,7 @@ ) from .vision_transformer import VisionTransformer from .vq_embeddings import VectorQuantizedEmbeddings +from .embedding_utils import resize_token_embeddings # usort: skip __all__ = [ "MultiHeadAttention", @@ -61,4 +62,5 @@ "prepare_layer_dropout", "classifier_model", "rms_norm", + "resize_token_embeddings", ] diff --git a/torchtune/modules/classifier.py b/torchtune/modules/classifier.py index 62363161fc..ac44f9ec25 100644 --- a/torchtune/modules/classifier.py +++ b/torchtune/modules/classifier.py @@ -10,6 +10,7 @@ from torchtune.config._utils import _get_component_from_path from torchtune.modules.transformer import TransformerDecoder + # TODO (SalmanMohammadi) - add a tutorial for fine-tuning classifiers def classifier_model( num_classes: int, base_model_path: str, **base_model_kwargs: Dict[str, Any] @@ -41,16 +42,11 @@ def classifier_model( """ model = _get_component_from_path(base_model_path)(**base_model_kwargs) + decoder = getattr(model, "decoder", model) if hasattr(model, "output"): - del model.output - model.output = nn.Linear( - model.head_dim * model.num_heads, num_classes, bias=False - ) - elif hasattr(model, "decoder") and hasattr(model.decoder, "output"): - del model.decoder.output - model.decoder.output = nn.Linear( - model.decoder.head_dim * model.decoder.num_heads, num_classes, bias=False + decoder.output = nn.Linear( + decoder.head_dim * decoder.num_heads, num_classes, bias=False ) else: raise ValueError("Could not find a valid output layer to adapt.") diff --git a/torchtune/modules/kv_cache.py b/torchtune/modules/kv_cache.py index 3d72e87adc..366670e0b2 100644 --- 
a/torchtune/modules/kv_cache.py +++ b/torchtune/modules/kv_cache.py @@ -64,7 +64,7 @@ def update( already been filled, use ``.reset()``, which will reset the cache to the zero-th position. Example: - >>> cache = KVCache(batch_size=2, max_seq_len=16, num_kv_heads=4, head_dim=32, dtype=torch.bfloat16) + >>> cache = KVCache(batch_size=2, num_kv_heads=4, max_seq_len=16, head_dim=32, dtype=torch.bfloat16) >>> keys, values = torch.ones((2, 4, 8, 32)), torch.ones((2, 4, 8, 32)) >>> cache.update(keys, values) >>> # now positions 0 through 7 are filled diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index 9ef16b3d85..ae1fb14d07 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -11,6 +11,8 @@ from torchtune.modules import MultiHeadAttention from torchtune.modules.attention_utils import _MaskType +from torchtune.utils import deprecated + class TransformerSelfAttentionLayer(nn.Module): """ @@ -133,18 +135,12 @@ def forward( # Norm applied before the feedforward layer mlp_out = self.mlp(self.mlp_norm(h)) - # import ipdb; ipdb.set_trace() + # Residual connection; shape: [batch_size, seq_length, embed_dim] out = h + self.mlp_scale(mlp_out) return out -def print_stats(name, x: torch.Tensor): - print(f"--{name}--") - print(f"max: {x.max()}, min: {x.min()}, mean: {x.float().mean()}, std: {x.float().std()}") - print(f"shape: {x.shape}") - # import ipdb; ipdb.set_trace() - class TransformerCrossAttentionLayer(nn.Module): """ Cross attention Transformer layer following the same conventions as the TransformerSelfAttentionLayer. @@ -402,11 +398,13 @@ def __init__( self.head_dim = head_dim self.causal_mask = None self.num_output_chunks = 0 + self.skip_output_layer = False # attributes for KV caches during inference self.encoder_max_cache_seq_len = None self.decoder_max_cache_seq_len = None + @deprecated("Please use LinearCrossEntropyLoss instead") def set_num_output_chunks(self, num_output_chunks: int) -> None: """Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`. This should be called before the first forward pass, in the recipe.""" @@ -491,7 +489,7 @@ def reset_caches(self): for layer in self.layers: layer.reset_cache() - @torch.compiler.disable + @deprecated("Please use self.skip_output_layer=True and use a linear loss instead") def chunked_output(self, last_hidden_state: torch.Tensor) -> List[torch.Tensor]: """ Apply output projection in chunks. This should be applied in conjunction with @@ -615,9 +613,9 @@ def forward( and skip straight to the transformer layers. Shape ``[b x s x d]``. Default: None Returns: - Union[torch.Tensor, List[torch.Tensor]]: output tensor with shape ``[b x s x v]`` or a list of layer - output tensors defined by ``output_hidden_states`` with the - final output tensor appended to the list. + Union[torch.Tensor, List[torch.Tensor]]: output tensor with shape ``[b x s x v]`` if `self.skip_output_layer=False` + and ``[b x s x d]`` otherwise, or a list of layer output tensors defined by ``output_hidden_states`` with the + final output tensor appended to the list. 
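        Example (illustrative sketch; ``decoder`` is assumed to be an already-constructed
        ``TransformerDecoder`` and ``tokens`` an integer tensor of shape ``[b x s]``):
            >>> decoder.skip_output_layer = True
            >>> hidden = decoder(tokens)   # shape [b x s x d], for use with a linear-loss style objective
            >>> decoder.skip_output_layer = False
            >>> logits = decoder(tokens)   # shape [b x s x v]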
Note: At the very first step of inference, when the model is provided with a prompt, @@ -680,8 +678,9 @@ def forward( def unembed(self, h): # shape: [b, s, d] h = self.norm(h) - - if self.num_output_chunks > 0: + if self.skip_output_layer: + output = h + elif self.num_output_chunks > 0: output = self.chunked_output(h) else: # shape: [b, seq_len, out_dim] diff --git a/torchtune/modules/vision_transformer.py b/torchtune/modules/vision_transformer.py index 6f261514b6..d80c5c7a17 100644 --- a/torchtune/modules/vision_transformer.py +++ b/torchtune/modules/vision_transformer.py @@ -381,6 +381,9 @@ def forward( h = x.reshape(bsz, n_imgs, n_tiles, n_tokens, embed_dim) hidden_states.append(h) x = transformer_layer(x) + if len(self.layers) in self.out_indices: + h = x.reshape(bsz, n_imgs, n_tiles, n_tokens, embed_dim) + hidden_states.append(h) # norm x = self.ln_post(x) From 5ff8127bea0752693099e64a2a92ec8e7cc9cf72 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Tue, 27 May 2025 17:35:12 +0100 Subject: [PATCH 12/25] more reverting --- torchtune/modules/classifier.py | 6 +- torchtune/modules/kv_cache.py | 6 +- torchtune/modules/moe/experts.py | 183 +++++++++++++++++++++++- torchtune/modules/moe/moe.py | 24 ++-- torchtune/modules/transformer.py | 26 ++-- torchtune/modules/vision_transformer.py | 11 +- 6 files changed, 212 insertions(+), 44 deletions(-) diff --git a/torchtune/modules/classifier.py b/torchtune/modules/classifier.py index ac44f9ec25..f70104ac88 100644 --- a/torchtune/modules/classifier.py +++ b/torchtune/modules/classifier.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Dict, Union +from typing import Any, Union import torch.nn as nn from torchtune.config._utils import _get_component_from_path @@ -13,7 +13,7 @@ # TODO (SalmanMohammadi) - add a tutorial for fine-tuning classifiers def classifier_model( - num_classes: int, base_model_path: str, **base_model_kwargs: Dict[str, Any] + num_classes: int, base_model_path: str, **base_model_kwargs: dict[str, Any] ) -> Union[TransformerDecoder, nn.Module]: """ Create a classifier model from a base model by adapting the output layer. @@ -26,7 +26,7 @@ def classifier_model( base_model_path (str): The path to the base model builder, which must return an instance of ``TransformerDecoder``, or a model with a ``decoder`` attribute that is an instance of ``TransformerDecoder``. - **base_model_kwargs (Dict[str, Any]): Keyword arguments for the base model. + **base_model_kwargs (dict[str, Any]): Keyword arguments for the base model. Returns: Union[TransformerDecoder, nn.Module]: The base model, with the output layer adapted for the number of classes. diff --git a/torchtune/modules/kv_cache.py b/torchtune/modules/kv_cache.py index 366670e0b2..fe800a69c5 100644 --- a/torchtune/modules/kv_cache.py +++ b/torchtune/modules/kv_cache.py @@ -4,8 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple - import torch from torch import nn @@ -55,7 +53,7 @@ def size(self) -> int: def update( self, k_val: torch.Tensor, v_val: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: """Update KV cache with the new ``k_val``, ``v_val`` and return the updated cache. 
Note: @@ -81,7 +79,7 @@ def update( v_val (torch.Tensor): Current value tensor with shape [B, H, S, D] Returns: - Tuple[torch.Tensor, torch.Tensor]: Updated key and value cache tensors, respectively. + tuple[torch.Tensor, torch.Tensor]: Updated key and value cache tensors, respectively. Raises: ValueError: if the batch size of the new key (or value) tensor is greater than the batch size diff --git a/torchtune/modules/moe/experts.py b/torchtune/modules/moe/experts.py index b14023857e..8b7984c786 100644 --- a/torchtune/modules/moe/experts.py +++ b/torchtune/modules/moe/experts.py @@ -4,11 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import math from typing import Callable import torch from torch import nn from torch.nn import functional as F +from torchtune.modules.peft import AdapterModule class GroupedExperts(nn.Module): @@ -38,6 +40,13 @@ def __init__( self.up_proj = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) self.act_fn = activation + def reset_parameters(self) -> None: + # Default initialization used by torch.nn.Linear + nn.init.kaiming_uniform_(self.gate_proj, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.down_proj, a=math.sqrt(5)) + if self.up_proj is not None: + nn.init.kaiming_uniform_(self.up_proj, a=math.sqrt(5)) + # TODO: force no inference mode as a hack to get around # "Cannot set version_counter for inference tensor" @torch.inference_mode(mode=False) @@ -53,10 +62,9 @@ def forward( enumerating the number of tokens each expert receives Returns: - torch.Tensor: tensor with shape (bsz * seq_len * experts_per_token, dim) + torch.Tensor: tensor with shape ``(bsz * seq_len * experts_per_token, dim)`` """ - # # import ipdb; ipdb.set_trace() # a tuple of tensors indexed by experts # each with shape (tokens_per_expert(varying), dim) x = torch.split( @@ -66,9 +74,6 @@ def forward( ) out_experts_splits = [] for expert_idx, x_expert in enumerate(x): - if x_expert.numel() == 0: - out_experts_splits.append(torch.zeros_like(x_expert)) - continue w1, w2, w3 = ( self.gate_proj[expert_idx], self.down_proj[expert_idx], @@ -82,3 +87,171 @@ def forward( out = torch.cat(out_experts_splits, dim=0) return out + + +class LoRAGroupedExperts(nn.Module, AdapterModule): + """This class implements the grouped experts layer used in Mixture of Experts with additional LoRA + adapter parameters. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension. + rank (int): rank of the low-rank approximation + alpha (float): scaling factor for the low-rank approximation + dropout (float): dropout probability before LoRA layer. Default: 0.0 + num_experts (int): Number of experts in this grouped experts layer. Default is 1. + activation (Callable): Activation function to use. Default is F.silu. 
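+    Example (illustrative sketch with arbitrary small sizes; every value below is made up
+    purely for demonstration):
+        >>> experts = LoRAGroupedExperts(dim=16, hidden_dim=32, rank=4, alpha=8.0, num_experts=4)
+        >>> tokens = torch.randn(10, 16)          # (bsz * seq_len * experts_per_token, dim)
+        >>> counts = torch.tensor([2, 3, 1, 4])   # tokens routed to each expert; sums to 10
+        >>> out = experts(tokens, counts)         # (10, 16)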
+ """ + + def __init__( + self, + *, + dim: int, + hidden_dim: int, + rank: int, + alpha: float, + dropout: float = 0.0, + num_experts: int = 1, + activation: Callable = F.silu, + ): + super().__init__() + self.dim = dim + self.num_experts = num_experts + self.rank = rank + self.alpha = alpha + self.gate_proj = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(num_experts, hidden_dim, dim)) + self.up_proj = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) + self.act_fn = activation + + # 'self.disabled' is a flag showing whether to turn off LoRA adapters, + # this can be used in DPO for treating the lora adapters as the policy model + # and disabling it to treat the base model as the reference model + self.disabled = False + self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity() + self.lora_gate_a = nn.Parameter(torch.empty(num_experts, dim, rank)) + self.lora_gate_b = nn.Parameter(torch.empty(num_experts, rank, hidden_dim)) + self.lora_down_a = nn.Parameter(torch.empty(num_experts, hidden_dim, rank)) + self.lora_down_b = nn.Parameter(torch.empty(num_experts, rank, dim)) + self.lora_up_a = nn.Parameter(torch.empty(num_experts, dim, rank)) + self.lora_up_b = nn.Parameter(torch.empty(num_experts, rank, hidden_dim)) + self.merged = False + self.initialize_parameters() + + def initialize_parameters(self) -> None: + # Default initialization used by torch.nn.Linear + nn.init.kaiming_uniform_(self.gate_proj, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.down_proj, a=math.sqrt(5)) + if self.up_proj is not None: + nn.init.kaiming_uniform_(self.up_proj, a=math.sqrt(5)) + + nn.init.kaiming_uniform_(self.lora_gate_a, a=math.sqrt(5)) + nn.init.zeros_(self.lora_gate_b) + nn.init.kaiming_uniform_(self.lora_down_a, a=math.sqrt(5)) + nn.init.zeros_(self.lora_down_b) + if self.lora_up_a is not None: + nn.init.kaiming_uniform_(self.lora_up_a, a=math.sqrt(5)) + nn.init.zeros_(self.lora_up_b) + + def adapter_params(self) -> list[str]: + """ + Return a list of strings corresponding to the names of the ``nn.Parameter`` s in + the model coming from the adapter. + + For LoRA this means lora_gate, lora_up, lora_down a and b weights. + """ + # NOTE: this function has to be updated if the names of the lora parameters + # in this module change. + adapter_params = [ + "lora_gate_a", + "lora_gate_b", + "lora_down_a", + "lora_down_b", + "lora_up_a", + "lora_up_b", + ] + return adapter_params + + def _lora_tc_layer_forward( + self, + x: torch.Tensor, + base_weight: torch.Tensor, + lora_a_weight: torch.Tensor, + lora_b_weight: torch.Tensor, + ) -> torch.Tensor: + """ + Forward pass a single linear layer with lora adapter layers for Token Choice routing. + + Args: + x (torch.Tensor): Input tensor with shape ``(tokens_per_expert, in_dim)``. + base_weight (torch.Tensor): weight of the base linear projection, shape ``(in_dim, out_dim)``. + lora_a_weight (torch.Tensor): weight of the lora adapter A layer, shape ``(in_dim, rank)``. + lora_b_weight (torch.Tensor): weight of the lora adapter B layer, shape ``(rank, out_dim)``. + + Returns: + torch.Tensor: Output tensor with shape ``(tokens_per_expert, out_dim)``. 
+ """ + out = torch.matmul(x, base_weight) + if self.disabled: + return out + lora_out = torch.matmul(self.dropout(x), lora_a_weight) + lora_out = (self.alpha / self.rank) * torch.matmul(lora_out, lora_b_weight) + return out + lora_out + + def forward( + self, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Tensor with shape ``(bsz * seq_len * experts_per_token, dim)`` + num_tokens_per_expert (torch.Tensor): Tensor with shape ``(num_experts,)`` + enumerating the number of tokens each expert receives + + Returns: + torch.Tensor: tuple of input tensors each with shape ``(num_experts, tokens_per_expert, dim)`` for Token Choice(TC) + or a single tensor with shape (num_experts, tokens_per_expert, dim) for Expert Choice(EC). + """ + # a tuple of tensors indexed by experts + # each with shape (tokens_per_expert(varying), dim) + x = torch.split( + x, + split_size_or_sections=num_tokens_per_expert.tolist(), + dim=0, + ) + out_experts_splits = [] + for expert_idx, x_expert in enumerate(x): + gate_proj, down_proj = ( + self.gate_proj[expert_idx], + self.down_proj[expert_idx], + ) + lora_gate_a, lora_gate_b, lora_down_a, lora_down_b = ( + self.lora_gate_a[expert_idx], + self.lora_gate_b[expert_idx], + self.lora_down_a[expert_idx], + self.lora_down_b[expert_idx], + ) + h = self.act_fn( + self._lora_tc_layer_forward( + x_expert, gate_proj, lora_gate_a, lora_gate_b + ) + ) + + if self.up_proj is not None: + up_proj = self.up_proj[expert_idx] + lora_up_a, lora_up_b = ( + self.lora_up_a[expert_idx], + self.lora_up_b[expert_idx], + ) + h = h * self._lora_tc_layer_forward( + x_expert, up_proj, lora_up_a, lora_up_b + ) + + h = self._lora_tc_layer_forward(h, down_proj, lora_down_a, lora_down_b) + + # h shape (tokens_per_expert(varying), hidden_dim) + out_experts_splits.append(h) + out = torch.cat(out_experts_splits, dim=0) + + return out diff --git a/torchtune/modules/moe/moe.py b/torchtune/modules/moe/moe.py index 2167d67c85..b6fd008356 100644 --- a/torchtune/modules/moe/moe.py +++ b/torchtune/modules/moe/moe.py @@ -61,6 +61,7 @@ def forward( scores, k=self.experts_per_token, dim=1 ) self.selected_experts_indices = selected_experts_indices + # top_scores /= top_scores.sum(dim=-1, keep_dim=True).to(x.dtype) # group tokens together by expert indices from 0 to num_experts and pass that to experts forward num_tokens_per_expert = torch.histc( @@ -86,7 +87,6 @@ class MoE(nn.Module): typically consists of a set of expert networks, alongside with a router, which directs input tokens to the appropriate experts. See more details in https://arxiv.org/pdf/2407.06204. - Args: experts (nn.Module): experts module. router (nn.Module): router module. @@ -113,36 +113,34 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. 
""" - b, s, dim = x.shape + bs, slen, dim = x.shape # top_scores and selected_indices shape (bs*slen*experts_per_token,) # num_tokens_per_expert shape (num_experts,) ( top_scores, token_indices, num_tokens_per_expert, - ) = self.router(x.reshape(b * s, dim)) + ) = self.router(x.reshape(bs * slen, dim)) - # shape (b*s*experts_per_token, dim) + # shape (bs*slen*experts_per_token, dim) token_indices = token_indices.reshape(-1, 1).expand(-1, dim) - # shape (b*s*experts_per_token, dim) + # shape (bs*slen*experts_per_token, dim) routed_input = torch.gather( x.view(-1, dim), dim=0, index=token_indices, ) - # routed_input = routed_input * top_scores.reshape(-1, 1) + routed_input = routed_input * top_scores.reshape(-1, 1) - # shape (b*s*top_k, dim) + # shape (bs*slen*top_k, dim) routed_output = self.experts(routed_input, num_tokens_per_expert) - routed_output = routed_output * top_scores.reshape(-1, 1) # shared expert if self.shared_expert is not None: - out = self.shared_expert(x).reshape(b * s, dim) + out = self.shared_expert(x).reshape(bs * slen, dim) else: - out = torch.zeros_like(x.reshape(b * s, dim)) - if routed_output.numel() > 0: - out.scatter_add_(dim=0, index=token_indices, src=routed_output) - out = out.reshape(b, s, dim) + out = torch.zeros_like(x.reshape(bs * slen, dim)) + out = out.scatter_add(dim=0, index=token_indices, src=routed_output) + out = out.reshape(bs, slen, dim) return out diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index ae1fb14d07..724138b14e 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import copy -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, Optional, Union import torch from torch import nn @@ -91,7 +91,7 @@ def forward( *, mask: Optional[_MaskType] = None, input_pos: Optional[torch.Tensor] = None, - **kwargs: Dict, + **kwargs: dict, ) -> torch.Tensor: """ Args: @@ -115,7 +115,7 @@ def forward( of each token relative to its sample when packed, shape [b x s]. During inference, this indicates the position of the current token. If none, assume the index of the token is its position id. Default is None. - **kwargs (Dict): transformer layer inputs not relevant to self attention. + **kwargs (dict): transformer layer inputs not relevant to self attention. Returns: torch.Tensor: output tensor with same shape as input @@ -258,7 +258,7 @@ def forward( *, encoder_input: Optional[torch.Tensor] = None, encoder_mask: Optional[torch.Tensor] = None, - **kwargs: Dict, + **kwargs: dict, ) -> torch.Tensor: """ Args: @@ -270,7 +270,7 @@ def forward( tokens and encoder embeddings. A True value at position i,j means token i can attend to embedding j in the decoder. Mask has shape [batch_size x token_sequence x embed_sequence]. Default is None. - **kwargs (Dict): transformer layer inputs not relevant to self attention. + **kwargs (dict): transformer layer inputs not relevant to self attention. Returns: torch.Tensor: output tensor with same shape as input @@ -335,7 +335,7 @@ class TransformerDecoder(nn.Module): Args: tok_embeddings (nn.Embedding): PyTorch embedding layer, to be used to move tokens to an embedding space. 
- layers (Union[nn.Module, List[nn.Module], nn.ModuleList]): A single transformer Decoder layer, an + layers (Union[nn.Module, list[nn.Module], nn.ModuleList]): A single transformer Decoder layer, an nn.ModuleList of layers or a list of layers. It is recommended to use an nn.ModuleList. max_seq_len (int): maximum sequence length the model will be run with, as used by :func:`~torchtune.modules.KVCache` @@ -350,7 +350,7 @@ class TransformerDecoder(nn.Module): the decoder. num_layers (Optional[int]): Number of Transformer Decoder layers, only define when layers is not a list. - output_hidden_states (Optional[List[int]]): List of layers (indices) to include in the output + output_hidden_states (Optional[list[int]]): list of layers (indices) to include in the output Raises: AssertionError: @@ -367,14 +367,14 @@ def __init__( self, *, tok_embeddings: nn.Embedding, - layers: Union[nn.Module, List[nn.Module], nn.ModuleList], + layers: Union[nn.Module, list[nn.Module], nn.ModuleList], max_seq_len: int, num_heads: int, head_dim: int, norm: nn.Module, output: Union[nn.Linear, Callable], num_layers: Optional[int] = None, - output_hidden_states: Optional[List[int]] = None, + output_hidden_states: Optional[list[int]] = None, ) -> None: super().__init__() if isinstance(layers, nn.ModuleList): @@ -490,7 +490,7 @@ def reset_caches(self): layer.reset_cache() @deprecated("Please use self.skip_output_layer=True and use a linear loss instead") - def chunked_output(self, last_hidden_state: torch.Tensor) -> List[torch.Tensor]: + def chunked_output(self, last_hidden_state: torch.Tensor) -> list[torch.Tensor]: """ Apply output projection in chunks. This should be applied in conjunction with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss` as upcasting to fp32 is done there. @@ -503,7 +503,7 @@ def chunked_output(self, last_hidden_state: torch.Tensor) -> List[torch.Tensor]: [b, seq_len, embed_dim]. Returns: - List[torch.Tensor]: List of num_chunks output tensors, each with shape + list[torch.Tensor]: List of num_chunks output tensors, each with shape [b, seq_len/num_chunks, out_dim], where out_dim is usually the vocab size. """ return [ @@ -580,7 +580,7 @@ def forward( encoder_mask: Optional[torch.Tensor] = None, input_pos: Optional[torch.Tensor] = None, input_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: + ) -> Union[torch.Tensor, list[torch.Tensor]]: """ Args: tokens (Optional[torch.Tensor]): input tensor with shape ``[b x s]`` @@ -613,7 +613,7 @@ def forward( and skip straight to the transformer layers. Shape ``[b x s x d]``. Default: None Returns: - Union[torch.Tensor, List[torch.Tensor]]: output tensor with shape ``[b x s x v]`` if `self.skip_output_layer=False` + Union[torch.Tensor, list[torch.Tensor]]: output tensor with shape ``[b x s x v]`` if `self.skip_output_layer=False` and ``[b x s x d]`` otherwise, or a list of layer output tensors defined by ``output_hidden_states`` with the final output tensor appended to the list. diff --git a/torchtune/modules/vision_transformer.py b/torchtune/modules/vision_transformer.py index d80c5c7a17..0652d653e4 100644 --- a/torchtune/modules/vision_transformer.py +++ b/torchtune/modules/vision_transformer.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import List, Optional, Tuple +from typing import Optional import torch from torch import nn @@ -181,7 +181,7 @@ class VisionTransformer(nn.Module): of shape (bsz * n_tiles, n_tokens, embed_dim) and output a tensor of shape (bsz * n_tiles, cls_output_dim). If provided, only the CLS token projection will be outputted, instead of all tokens. - out_indices (Optional[List[int]]): The indices of hidden layers to return. + out_indices (Optional[list[int]]): The indices of hidden layers to return. If provided, it will return the intermediate results of the transformer layers before they go through a next layer. For example, ``out_indices=[0,3]`` will return the tokens before they go through the first and fourth layers. @@ -207,7 +207,7 @@ def __init__( pre_tile_pos_embed: Optional[nn.Module] = None, post_tile_pos_embed: Optional[nn.Module] = None, cls_projection: Optional[nn.Module] = None, - out_indices: Optional[List[int]] = None, + out_indices: Optional[list[int]] = None, in_channels: int = 3, append_cls_token: bool = False, ) -> None: @@ -260,7 +260,7 @@ def forward( self, images: torch.Tensor, aspect_ratio: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + ) -> tuple[torch.Tensor, list[torch.Tensor]]: """ Processes images and returns the tokens and hidden states. @@ -282,7 +282,7 @@ def forward( Used to calculate the positional embeddings for the tiles. Returns: - Tuple[torch.Tensor, List[torch.Tensor]]: A tuple: (x, hidden_states), + tuple[torch.Tensor, list[torch.Tensor]]: A tuple: (x, hidden_states), where x is a torch.tensor of shape (bsz, n_imgs, n_tiles, n_tokens, embed_dim) and hidden_states has shape is a list of len(out_indices) torch.tensor with shape (bsz, n_imgs, n_tiles, n_tokens, embed_dim). @@ -424,7 +424,6 @@ def __init__(self, embed_dim: int, append_cls_token: bool = False) -> None: self.append_cls_token = append_cls_token def forward(self, x: torch.Tensor) -> torch.Tensor: - # add 1 CLS token to every tile bsz_and_n_imgs, n_tiles, n_tokens, embed_dim = x.shape cls_emb = self.weight.broadcast_to(bsz_and_n_imgs, n_tiles, 1, embed_dim) From 9b841f4febd8528e4bb0f3439a5080c57c079224 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Tue, 27 May 2025 17:39:08 +0100 Subject: [PATCH 13/25] removing comments --- .../models/deepseek_v3/_convert_weights.py | 130 ------------------ 1 file changed, 130 deletions(-) diff --git a/torchtune/models/deepseek_v3/_convert_weights.py b/torchtune/models/deepseek_v3/_convert_weights.py index 7cea272fc0..773472627b 100644 --- a/torchtune/models/deepseek_v3/_convert_weights.py +++ b/torchtune/models/deepseek_v3/_convert_weights.py @@ -1,138 +1,8 @@ -from collections import defaultdict import torch from torchtune.models.convert_weights import get_mapped_key import regex as re from typing import Dict -# hf_model -# DeepseekV3ForCausalLM( -# (model): DeepseekV3Model( -# (embed_tokens): Identity() -# (layers): ModuleList( -# (0): DeepseekV3DecoderLayer( -# (self_attn): DeepseekV3Attention( -# (q_a_proj): Linear(in_features=16, out_features=16, bias=False) -# (q_a_layernorm): DeepseekV3RMSNorm((16,), eps=1e-06) -# (q_b_proj): Linear(in_features=16, out_features=64, bias=False) -# (kv_a_proj_with_mqa): Linear(in_features=16, out_features=32, bias=False) -# (kv_a_layernorm): DeepseekV3RMSNorm((16,), eps=1e-06) -# (kv_b_proj): Linear(in_features=16, out_features=64, bias=False) -# (o_proj): Linear(in_features=32, out_features=16, bias=False) -# ) -# (mlp): DeepseekV3MLP( -# (gate_proj): Linear(in_features=16, 
out_features=32, bias=False) -# (up_proj): Linear(in_features=16, out_features=32, bias=False) -# (down_proj): Linear(in_features=32, out_features=16, bias=False) -# (act_fn): SiLU() -# ) -# (input_layernorm): DeepseekV3RMSNorm((16,), eps=1e-06) -# (post_attention_layernorm): DeepseekV3RMSNorm((16,), eps=1e-06) -# ) -# (1): DeepseekV3DecoderLayer( -# (self_attn): DeepseekV3Attention( -# (q_a_proj): Linear(in_features=16, out_features=16, bias=False) -# (q_a_layernorm): DeepseekV3RMSNorm((16,), eps=1e-06) -# (q_b_proj): Linear(in_features=16, out_features=64, bias=False) -# (kv_a_proj_with_mqa): Linear(in_features=16, out_features=32, bias=False) -# (kv_a_layernorm): DeepseekV3RMSNorm((16,), eps=1e-06) -# (kv_b_proj): Linear(in_features=16, out_features=64, bias=False) -# (o_proj): Linear(in_features=32, out_features=16, bias=False) -# ) -# (mlp): DeepseekV3MoE( -# (experts): ModuleList( -# (0-255): 256 x DeepseekV3MLP( -# (gate_proj): Linear(in_features=16, out_features=16, bias=False) -# (up_proj): Linear(in_features=16, out_features=16, bias=False) -# (down_proj): Linear(in_features=16, out_features=16, bias=False) -# (act_fn): SiLU() -# ) -# ) -# (gate): DeepseekV3TopkRouter() -# (shared_experts): DeepseekV3MLP( -# (gate_proj): Linear(in_features=16, out_features=16, bias=False) -# (up_proj): Linear(in_features=16, out_features=16, bias=False) -# (down_proj): Linear(in_features=16, out_features=16, bias=False) -# (act_fn): SiLU() -# ) -# ) -# (input_layernorm): DeepseekV3RMSNorm((16,), eps=1e-06) -# (post_attention_layernorm): DeepseekV3RMSNorm((16,), eps=1e-06) -# ) -# ) -# (norm): DeepseekV3RMSNorm((16,), eps=1e-06) -# (rotary_emb): DeepseekV3RotaryEmbedding() -# ) -# (lm_head): Linear(in_features=16, out_features=129280, bias=False) -# ) -# TransformerDecoder( -# (tok_embeddings): Identity() -# (layers): ModuleList( -# (0): TransformerSelfAttentionLayer( -# (attn): DeepSeekV3Attention( -# (q_proj): DeepSeekV3LatentLinear( -# (a): Linear(in_features=16, out_features=16, bias=False) -# (b): Linear(in_features=16, out_features=64, bias=False) -# (norm): RMSNorm() -# ) -# (kv_proj): DeepSeekV3LatentLinear( -# (a): Linear(in_features=16, out_features=32, bias=False) -# (b): Linear(in_features=16, out_features=64, bias=False) -# (norm): RMSNorm() -# ) -# (output_proj): Linear(in_features=32, out_features=16, bias=False) -# (pos_embeddings): Identity() -# ) -# (mlp): FeedForward( -# (w1): Linear(in_features=16, out_features=32, bias=False) -# (w2): Linear(in_features=32, out_features=16, bias=False) -# (w3): Linear(in_features=16, out_features=32, bias=False) -# (activation): SiLU() -# ) -# (sa_norm): RMSNorm() -# (mlp_norm): RMSNorm() -# (sa_scale): Identity() -# (mlp_scale): Identity() -# ) -# (1): TransformerSelfAttentionLayer( -# (attn): DeepSeekV3Attention( -# (q_proj): DeepSeekV3LatentLinear( -# (a): Linear(in_features=16, out_features=16, bias=False) -# (b): Linear(in_features=16, out_features=64, bias=False) -# (norm): RMSNorm() -# ) -# (kv_proj): DeepSeekV3LatentLinear( -# (a): Linear(in_features=16, out_features=32, bias=False) -# (b): Linear(in_features=16, out_features=64, bias=False) -# (norm): RMSNorm() -# ) -# (output_proj): Linear(in_features=32, out_features=16, bias=False) -# (pos_embeddings): Identity() -# ) -# (mlp): MoE( -# (experts): GroupedExperts() -# (router): DeepSeekV3TokenChoiceTopKRouter( -# (gate): Linear(in_features=16, out_features=256, bias=False) -# ) -# (shared_expert): FeedForward( -# (w1): Linear(in_features=16, out_features=16, bias=False) -# (w2): 
Linear(in_features=16, out_features=16, bias=False) -# (w3): Linear(in_features=16, out_features=16, bias=False) -# (activation): SiLU() -# ) -# ) -# (sa_norm): RMSNorm() -# (mlp_norm): RMSNorm() -# (sa_scale): Identity() -# (mlp_scale): Identity() -# ) -# ) -# (norm): RMSNorm() -# (output): Linear(in_features=16, out_features=129280, bias=False) -# ) - -# state dict key mappings from HF's format to torchtune's format for DeepSeek V3 -# Note: Conversion might require custom logic beyond simple key mapping, -# especially for kv_proj and MoE expert weights. _FROM_HF = { "model.embed_tokens.weight": "tok_embeddings.weight", "model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale", From 7a147327644c2434653ba3c2de42f6780d026e48 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Sun, 8 Jun 2025 11:42:20 -0700 Subject: [PATCH 14/25] adding yarn --- torchtune/datasets/_instruct.py | 2 +- .../models/deepseek_v3/_component_builders.py | 23 +- .../deepseek_v3/_position_embeddings.py | 462 ++++++++++++++++++ torchtune/models/llama4/_convert_weights.py | 2 +- .../models/llama4/_position_embeddings.py | 2 +- torchtune/models/qwen2/__init__.py | 2 +- torchtune/models/qwen2/_component_builders.py | 2 +- ..._embeddings.py => _position_embeddings.py} | 2 +- torchtune/modules/moe/moe.py | 1 - 9 files changed, 484 insertions(+), 14 deletions(-) rename torchtune/models/qwen2/{_positional_embeddings.py => _position_embeddings.py} (98%) diff --git a/torchtune/datasets/_instruct.py b/torchtune/datasets/_instruct.py index 43ff686756..eeb30feac0 100644 --- a/torchtune/datasets/_instruct.py +++ b/torchtune/datasets/_instruct.py @@ -116,7 +116,7 @@ def instruct_dataset( dataset: _component_: torchtune.datasets.instruct_dataset source: json - data_files: my_dataset.json + data_files: my_dataset.json column_map: input: question output: answer diff --git a/torchtune/models/deepseek_v3/_component_builders.py b/torchtune/models/deepseek_v3/_component_builders.py index 127d52ac59..a563048801 100644 --- a/torchtune/models/deepseek_v3/_component_builders.py +++ b/torchtune/models/deepseek_v3/_component_builders.py @@ -18,9 +18,8 @@ TransformerDecoder, TransformerSelfAttentionLayer, ) -from torchtune.modules.moe.experts import GroupedExperts from torchtune.modules.position_embeddings import RotaryPositionalEmbeddings - +from torchtune.models.deepseek_v3._position_embeddings import DeepSeekV3YarnRotaryEmbeddings def deepseek_v3( *, @@ -48,11 +47,21 @@ def deepseek_v3( moe_hidden_dim: Optional[int] = None, norm_eps: float = 1e-5, ): - if use_yarn: - raise NotImplementedError("Yarn is not supported yet") - rope = RotaryPositionalEmbeddings(dim=qk_rope_head_dim, max_seq_len=max_seq_len, base=rope_base) - def rope(x, input_pos=None): - return x + # if use_yarn: + # raise NotImplementedError("Yarn is not supported yet") + rope = DeepSeekV3YarnRotaryEmbeddings( + dim=qk_rope_head_dim, + max_seq_len=max_seq_len, + base=rope_base, + scaling_factor=rope_scaling_factor, + original_max_seq_len=original_max_seq_len, + beta_fast=beta_fast, + beta_slow=beta_slow, + mscale=mscale, + mscale_all_dim=mscale_all_dim, + ) + # def rope(x, input_pos=None): + # return x layers = [] for i in range(num_layers): q_head_dim = qk_rope_head_dim + qk_nope_head_dim diff --git a/torchtune/models/deepseek_v3/_position_embeddings.py b/torchtune/models/deepseek_v3/_position_embeddings.py index e69de29bb2..bbb8f55c93 100644 --- a/torchtune/models/deepseek_v3/_position_embeddings.py +++ b/torchtune/models/deepseek_v3/_position_embeddings.py @@ -0,0 
+1,462 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Optional + +import torch +from torch import nn + + +class DeepSeekV3YarnRotaryEmbeddings(nn.Module): + """ + This class implements YaRN (Yet another RoPE extensioN) Rotary Positional Embeddings + for DeepSeek v3, proposed in https://arxiv.org/abs/2309.00071. + + YaRN extends RoPE to longer sequence lengths by selectively applying frequency scaling + to different parts of the frequency spectrum based on wavelength characteristics. + It also includes magnitude scaling to preserve attention patterns. + + Args: + dim (int): Embedding dimension. This is usually set to the dim of each + head in the attention module computed as ``embed_dim // num_heads`` + max_seq_len (int): Maximum expected sequence length for the + model, if exceeded the cached freqs will be recomputed + base (int): The base for the geometric progression used to compute + the rotation angles + scaling_factor (float): Factor by which to scale the original context length + original_max_seq_len (int): Original maximum sequence length before scaling + beta_fast (float): Lower bound for frequency scaling range. Default: 32 + beta_slow (float): Upper bound for frequency scaling range. Default: 1 + mscale (float): Magnitude scaling factor. Default: 1 + mscale_all_dim (float): Magnitude scaling for all dimensions. Default: 0 + """ + + def __init__( + self, + dim: int, + max_seq_len: int = 4096, + base: int = 10_000, + scaling_factor: float = 1.0, + original_max_seq_len: int = 4096, + beta_fast: float = 32.0, + beta_slow: float = 1.0, + mscale: float = 1.0, + mscale_all_dim: float = 0.0, + ) -> None: + super().__init__() + self.dim = dim + self.base = base + self.max_seq_len = max_seq_len + self.scaling_factor = scaling_factor + self.original_max_seq_len = original_max_seq_len + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + self.rope_init() + + def _yarn_find_correction_dim( + self, num_rotations: float, dim: int, base: int, max_position_embeddings: int + ) -> float: + """Find dimension based on number of rotations using inverse formula.""" + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + def _yarn_find_correction_range( + self, low_rot: float, high_rot: float, dim: int, base: int, max_position_embeddings: int + ) -> tuple[int, int]: + """Find dimension range bounds based on rotations.""" + low = math.floor( + self._yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + self._yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) + + def _yarn_get_mscale(self, scale: float = 1.0, mscale: float = 1.0) -> float: + """Calculate magnitude scaling factor.""" + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + def _yarn_linear_ramp_mask(self, min_val: int, max_val: int, dim: int) -> torch.Tensor: + """Create linear ramp mask for smooth frequency interpolation.""" + if min_val == max_val: + max_val += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + def rope_init(self): + """Initialize the 
YaRN RoPE embeddings.""" + # Compute base extrapolated freqs + freq_base = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim) + ) + + # Compute scaled intre6-polated freqs + freq_interp = 1.0 / ( + self.scaling_factor + * self.base ** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim) + ) + + # Find correction range for frequency interpolation + low, high = self._yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + self.dim, + self.base, + self.original_max_seq_len, + ) + + # Create interpolation mask + inv_freq_mask = 1.0 - self._yarn_linear_ramp_mask(low, high, self.dim // 2) + + # Interpolate between scaled and unscaled frequencies + theta = freq_interp * (1 - inv_freq_mask) + freq_base * inv_freq_mask + + self.register_buffer("theta", theta, persistent=False) + self.build_rope_cache(self.max_seq_len) + + def build_rope_cache(self, max_seq_len: int = 4096) -> None: + """Build the RoPE cache with YaRN scaling.""" + # Create position indexes `[0, 1, ..., max_seq_len - 1]` + seq_idx = torch.arange( + max_seq_len, dtype=self.theta.dtype, device=self.theta.device + ) + + # Outer product of theta and position index; output tensor has + # a shape of [max_seq_len, dim // 2] + idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float() + + # Calculate magnitude scaling + mscale = float( + self._yarn_get_mscale(self.scaling_factor, self.mscale) + / self._yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) + + # cache includes both the cos and sin components and so the output shape is + # [max_seq_len, dim // 2, 2] + cache = torch.stack([idx_theta.cos() * mscale, idx_theta.sin() * mscale], dim=-1) + self.register_buffer("cache", cache, persistent=False) + + def forward( + self, x: torch.Tensor, *, input_pos: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Apply YaRN RoPE to input tensor. 
+ + Args: + x (torch.Tensor): input tensor with shape ``[b, s, n_h, h_d]`` + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids + + Returns: + torch.Tensor: output tensor with shape ``[b, s, n_h, h_d]`` + """ + # input tensor has shape [b, s, n_h, h_d] + seq_len = x.size(1) + + # extract the values based on whether input_pos is set or not + rope_cache = ( + self.cache[:seq_len] if input_pos is None else self.cache[input_pos] + ) + + # Alternative: Use reference-style rotation for comparison + cos = rope_cache[..., 0] # [seq_len, dim//2] + sin = rope_cache[..., 1] # [seq_len, dim//2] + + # Expand for broadcasting: [1, seq_len, 1, dim//2] + cos = cos.unsqueeze(0).unsqueeze(2) + sin = sin.unsqueeze(0).unsqueeze(2) + + # Split input into two halves + x1 = x[..., : x.shape[-1] // 2] # [b, s, n_h, dim//2] + x2 = x[..., x.shape[-1] // 2:] # [b, s, n_h, dim//2] + + # Apply rotation + rotated_x1 = x1 * cos - x2 * sin + rotated_x2 = x1 * sin + x2 * cos + + return torch.cat([rotated_x1, rotated_x2], dim=-1) + + +# Reference implementation for comparison +class ReferenceYarnRoPE(nn.Module): + """Reference implementation based on DeepSeek's YaRN RoPE""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + scaling_factor=1.0, + original_max_position_embeddings=4096, + beta_fast=32, + beta_slow=1, + mscale=1, + mscale_all_dim=0, + ): + super().__init__() + self.dim = dim + self.base = base + self.max_position_embeddings = max_position_embeddings + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + self._set_cos_sin_cache(max_position_embeddings, torch.device("cpu"), torch.float32) + + def yarn_find_correction_dim(self, num_rotations, dim, base, max_position_embeddings): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + def yarn_find_correction_range(self, low_rot, high_rot, dim, base, max_position_embeddings): + low = math.floor( + self.yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + self.yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) + + def yarn_get_mscale(self, scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + def yarn_linear_ramp_mask(self, min, max, dim): + if min == max: + max += 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + dim = self.dim + + freq_extra = 1.0 / ( + self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + freq_inter = 1.0 / ( + self.scaling_factor + * self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + + low, high = self.yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - self.yarn_linear_ramp_mask(low, high, dim // 2).to( + device=device, dtype=torch.float32 + ) + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(seq_len, device=device, 
dtype=torch.float32) + + freqs = torch.outer(t, inv_freq) + + _mscale = float( + self.yarn_get_mscale(self.scaling_factor, self.mscale) + / self.yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False + ) + + def forward(self, x): + seq_len = x.shape[1] + cos = self.cos_cached[:seq_len].unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim] + sin = self.sin_cached[:seq_len].unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim] + + # Split the last dimension + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + + # Since cos and sin are duplicated (freqs concatenated with itself), + # we only need the first half for each rotation pair + cos_half = cos[..., : cos.shape[-1] // 2] # [1, seq_len, 1, dim//2] + sin_half = sin[..., : sin.shape[-1] // 2] # [1, seq_len, 1, dim//2] + + # Apply rotation + rotated_x1 = x1 * cos_half - x2 * sin_half + rotated_x2 = x1 * sin_half + x2 * cos_half + + return torch.cat([rotated_x1, rotated_x2], dim=-1) + + +def print_table(title, data): + """Print results in a nice table format""" + print(f"\n{title}") + print("=" * len(title)) + + # Find maximum width for each column + headers = list(data[0].keys()) + widths = [max(len(str(row[col])) for row in data + [dict(zip(headers, headers))]) for col in headers] + + # Print header + header_row = " | ".join(f"{headers[i]:<{widths[i]}}" for i in range(len(headers))) + print(header_row) + print("-" * len(header_row)) + + # Print data rows + for row in data: + data_row = " | ".join(f"{str(row[col]):<{widths[i]}}" for i, col in enumerate(headers)) + print(data_row) + + +if __name__ == "__main__": + print("Testing YaRN RoPE implementation...") + + # Test parameters + batch_size = 2 + seq_len = 512 + num_heads = 4 + head_dim = 64 + + # Create test input + torch.manual_seed(42) + x = torch.randn(batch_size, seq_len, num_heads, head_dim) + print(f"Input shape: {x.shape}") + + # Test configurations + test_configs = [ + {"scale": 1.0, "beta_fast": 32, "beta_slow": 1}, + {"scale": 2.0, "beta_fast": 32, "beta_slow": 1}, + {"scale": 4.0, "beta_fast": 32, "beta_slow": 1}, + {"scale": 2.0, "beta_fast": 16, "beta_slow": 2}, + ] + + results = [] + + for config in test_configs: + # Create our implementation + our_yarn = DeepSeekV3YarnRotaryEmbeddings( + dim=head_dim, + max_seq_len=1024, + scaling_factor=config["scale"], + original_max_seq_len=512, + beta_fast=config["beta_fast"], + beta_slow=config["beta_slow"], + mscale=1, + mscale_all_dim=0 + ) + + # Create reference implementation + ref_yarn = ReferenceYarnRoPE( + dim=head_dim, + max_position_embeddings=1024, + scaling_factor=config["scale"], + original_max_position_embeddings=512, + beta_fast=config["beta_fast"], + beta_slow=config["beta_slow"], + mscale=1, + mscale_all_dim=0 + ) + + # Run forward passes + our_output = our_yarn(x) + ref_output = ref_yarn(x) + + # Calculate metrics + freq_match = torch.allclose(our_yarn.theta, ref_yarn.inv_freq, atol=1e-6) + cos_match = torch.allclose(our_yarn.cache[:seq_len, :, 0], + ref_yarn.cos_cached[:seq_len, :head_dim // 2], atol=1e-6) + sin_match = torch.allclose(our_yarn.cache[:seq_len, :, 1], + ref_yarn.sin_cached[:seq_len, :head_dim // 2], atol=1e-6) + output_match = torch.allclose(our_output, ref_output, atol=1e-5) + + max_diff = (our_output - ref_output).abs().max().item() + mean_diff = (our_output - 
ref_output).abs().mean().item() + + results.append({ + "Scale": f"{config['scale']}x", + "Beta Range": f"[{config['beta_slow']}, {config['beta_fast']}]", + "Freq Match": "✓" if freq_match else "✗", + "Cos Match": "✓" if cos_match else "✗", + "Sin Match": "✓" if sin_match else "✗", + "Output Match": "✓" if output_match else "✗", + "Max Diff": f"{max_diff:.2e}", + "Mean Diff": f"{mean_diff:.2e}" + }) + + print_table("YaRN RoPE Comparison Results", results) + + # Detailed analysis for 2x scaling + print("\n" + "=" * 50) + print("DETAILED ANALYSIS FOR 2x SCALING") + print("=" * 50) + + our_yarn = DeepSeekV3YarnRotaryEmbeddings( + dim=head_dim, max_seq_len=1024, scaling_factor=2.0, + original_max_seq_len=512, beta_fast=32, beta_slow=1 + ) + ref_yarn = ReferenceYarnRoPE( + dim=head_dim, max_position_embeddings=1024, scaling_factor=2.0, + original_max_position_embeddings=512, beta_fast=32, beta_slow=1 + ) + + our_output = our_yarn(x) + ref_output = ref_yarn(x) + + analysis_data = [ + {"Component": "Theta/InvFreq", "Shape": str(our_yarn.theta.shape), "Match": "✓" if torch.allclose( + our_yarn.theta, ref_yarn.inv_freq, atol=1e-6) else "✗"}, + {"Component": "Cos Cache", "Shape": f"{our_yarn.cache.shape[0]}x{our_yarn.cache.shape[1]}", "Match": "✓" if torch.allclose( + our_yarn.cache[:seq_len, :, 0], ref_yarn.cos_cached[:seq_len, :head_dim // 2], atol=1e-6) else "✗"}, + {"Component": "Sin Cache", "Shape": f"{our_yarn.cache.shape[0]}x{our_yarn.cache.shape[1]}", "Match": "✓" if torch.allclose( + our_yarn.cache[:seq_len, :, 1], ref_yarn.sin_cached[:seq_len, :head_dim // 2], atol=1e-6) else "✗"}, + {"Component": "Final Output", "Shape": str(our_output.shape), "Match": "✓" if torch.allclose( + our_output, ref_output, atol=1e-5) else "✗"}, + ] + + print_table("Component Analysis", analysis_data) + + # Statistics comparison + stats_data = [ + {"Metric": "Mean", "Our Impl": f"{our_output.mean():.6f}", "Reference": f"{ref_output.mean():.6f}", + "Diff": f"{abs(our_output.mean() - ref_output.mean()):.2e}"}, + {"Metric": "Std", "Our Impl": f"{our_output.std():.6f}", "Reference": f"{ref_output.std():.6f}", + "Diff": f"{abs(our_output.std() - ref_output.std()):.2e}"}, + {"Metric": "Min", "Our Impl": f"{our_output.min():.6f}", "Reference": f"{ref_output.min():.6f}", + "Diff": f"{abs(our_output.min() - ref_output.min()):.2e}"}, + {"Metric": "Max", "Our Impl": f"{our_output.max():.6f}", "Reference": f"{ref_output.max():.6f}", + "Diff": f"{abs(our_output.max() - ref_output.max()):.2e}"}, + ] + + print_table("Output Statistics", stats_data) + + print(f"\nOverall Assessment:") + print(f"• Frequencies match: Perfect ✓") + print(f"• Cached values match: Perfect ✓") + print( + f"• Final outputs match: {'Perfect ✓' if torch.allclose(our_output, ref_output, atol=1e-5) else 'Close but not exact'}") + print(f"• Max difference: {(our_output - ref_output).abs().max():.2e}") + + if torch.allclose(our_output, ref_output, atol=1e-4): + print(f"• Assessment: ✅ EXCELLENT - Differences are within numerical precision") + elif torch.allclose(our_output, ref_output, atol=1e-3): + print(f"• Assessment: ✅ GOOD - Small differences, likely implementation variants") + else: + print(f"• Assessment: ⚠️ NEEDS INVESTIGATION - Significant differences detected") diff --git a/torchtune/models/llama4/_convert_weights.py b/torchtune/models/llama4/_convert_weights.py index 0d8de255a3..bf256e0128 100644 --- a/torchtune/models/llama4/_convert_weights.py +++ b/torchtune/models/llama4/_convert_weights.py @@ -234,7 +234,7 @@ def llama4_tune_to_hf( # Combine 
gate projection with up projection new_key = get_mapped_key(key, inverted_mapping_dict) up_proj = state_dict[key.replace("gate", "up")] - converted_state_dict[new_key] = torch.cat([value, up_proj], dim=-1 ) + converted_state_dict[new_key] = torch.cat([value, up_proj], dim=-1) continue elif key.endswith("experts.up_proj"): # Skip as already handled with gate projection diff --git a/torchtune/models/llama4/_position_embeddings.py b/torchtune/models/llama4/_position_embeddings.py index 3adfc32738..5f8b88a1d1 100644 --- a/torchtune/models/llama4/_position_embeddings.py +++ b/torchtune/models/llama4/_position_embeddings.py @@ -21,7 +21,7 @@ class Llama4ScaledRoPE(nn.Module): In this implementation we cache the embeddings for each position upto ``max_seq_len`` by computing this during init. - Note that this class is identical to :class:`~torchtune.models.llama4.Llama3ScaledRoPE`, but with different default values + Note that this class is identical to :class:`~torchtune.models.llama3_1.Llama3ScaledRoPE`, but with different default values for scaling factors, as set for Llama4 Scout. See the meta-llama reference code here: https://github.com/meta-llama/llama-models/blob/28fa4e3b287e84f6a6a92aab3c931f7479c827c1/models/llama4/args.py#L100-L107 diff --git a/torchtune/models/qwen2/__init__.py b/torchtune/models/qwen2/__init__.py index 8e04fba85d..169a67dab4 100644 --- a/torchtune/models/qwen2/__init__.py +++ b/torchtune/models/qwen2/__init__.py @@ -15,7 +15,7 @@ qwen2_7b, qwen2_tokenizer, ) -from ._positional_embeddings import Qwen2RotaryPositionalEmbeddings +from ._position_embeddings import Qwen2RotaryPositionalEmbeddings from ._tokenizer import Qwen2Tokenizer __all__ = [ diff --git a/torchtune/models/qwen2/_component_builders.py b/torchtune/models/qwen2/_component_builders.py index 45e0cb1d60..5ce667f0ed 100644 --- a/torchtune/models/qwen2/_component_builders.py +++ b/torchtune/models/qwen2/_component_builders.py @@ -10,7 +10,7 @@ from torch import nn from torchtune.modules.transformer import TransformerDecoder -from torchtune.models.qwen2._positional_embeddings import Qwen2RotaryPositionalEmbeddings +from torchtune.models.qwen2._position_embeddings import Qwen2RotaryPositionalEmbeddings from torchtune.modules import ( MultiHeadAttention, diff --git a/torchtune/models/qwen2/_positional_embeddings.py b/torchtune/models/qwen2/_position_embeddings.py similarity index 98% rename from torchtune/models/qwen2/_positional_embeddings.py rename to torchtune/models/qwen2/_position_embeddings.py index b6265e6c9d..48b74ba8e4 100644 --- a/torchtune/models/qwen2/_positional_embeddings.py +++ b/torchtune/models/qwen2/_position_embeddings.py @@ -49,7 +49,7 @@ def rope_init(self): self.build_rope_cache(self.max_seq_len) def build_rope_cache(self, max_seq_len: int = 4096) -> None: - # Create position indexes `[0, 1, ..., max_seq_len - 1] + # Create position indexes [0, 1, ..., max_seq_len - 1] seq_idx = torch.arange( max_seq_len, dtype=self.theta.dtype, device=self.theta.device ) diff --git a/torchtune/modules/moe/moe.py b/torchtune/modules/moe/moe.py index b6fd008356..15b83dd13a 100644 --- a/torchtune/modules/moe/moe.py +++ b/torchtune/modules/moe/moe.py @@ -109,7 +109,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``. - Returns: out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. 
""" From 0f899c1dbb1058889bd7f44d25d49dfd0af7b42f Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 11 Jun 2025 12:41:24 -0700 Subject: [PATCH 15/25] YaRN got to be kiddin me --- .../deepseek_v3/_position_embeddings.py | 333 ++---------------- 1 file changed, 34 insertions(+), 299 deletions(-) diff --git a/torchtune/models/deepseek_v3/_position_embeddings.py b/torchtune/models/deepseek_v3/_position_embeddings.py index bbb8f55c93..821c4e418b 100644 --- a/torchtune/models/deepseek_v3/_position_embeddings.py +++ b/torchtune/models/deepseek_v3/_position_embeddings.py @@ -62,7 +62,6 @@ def __init__( def _yarn_find_correction_dim( self, num_rotations: float, dim: int, base: int, max_position_embeddings: int ) -> float: - """Find dimension based on number of rotations using inverse formula.""" return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( 2 * math.log(base) ) @@ -70,7 +69,6 @@ def _yarn_find_correction_dim( def _yarn_find_correction_range( self, low_rot: float, high_rot: float, dim: int, base: int, max_position_embeddings: int ) -> tuple[int, int]: - """Find dimension range bounds based on rotations.""" low = math.floor( self._yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) ) @@ -80,28 +78,25 @@ def _yarn_find_correction_range( return max(low, 0), min(high, dim - 1) def _yarn_get_mscale(self, scale: float = 1.0, mscale: float = 1.0) -> float: - """Calculate magnitude scaling factor.""" if scale <= 1: return 1.0 return 0.1 * mscale * math.log(scale) + 1.0 def _yarn_linear_ramp_mask(self, min_val: int, max_val: int, dim: int) -> torch.Tensor: - """Create linear ramp mask for smooth frequency interpolation.""" if min_val == max_val: - max_val += 0.001 # Prevent singularity + max_val += 0.001 linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val) ramp_func = torch.clamp(linear_func, 0, 1) return ramp_func def rope_init(self): - """Initialize the YaRN RoPE embeddings.""" # Compute base extrapolated freqs freq_base = 1.0 / ( self.base ** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim) ) - # Compute scaled intre6-polated freqs + # Compute scaled interpolated freqs freq_interp = 1.0 / ( self.scaling_factor * self.base ** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim) @@ -126,7 +121,6 @@ def rope_init(self): self.build_rope_cache(self.max_seq_len) def build_rope_cache(self, max_seq_len: int = 4096) -> None: - """Build the RoPE cache with YaRN scaling.""" # Create position indexes `[0, 1, ..., max_seq_len - 1]` seq_idx = torch.arange( max_seq_len, dtype=self.theta.dtype, device=self.theta.device @@ -151,14 +145,23 @@ def forward( self, x: torch.Tensor, *, input_pos: Optional[torch.Tensor] = None ) -> torch.Tensor: """ - Apply YaRN RoPE to input tensor. - Args: - x (torch.Tensor): input tensor with shape ``[b, s, n_h, h_d]`` + x (torch.Tensor): input tensor with shape + ``[b, s, n_h, h_d]`` input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape [b, s]. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. 
Returns: torch.Tensor: output tensor with shape ``[b, s, n_h, h_d]`` + + Notation used for tensor shapes: + - b: batch size + - s: sequence length + - n_h: num heads + - h_d: head dim """ # input tensor has shape [b, s, n_h, h_d] seq_len = x.size(1) @@ -168,295 +171,27 @@ def forward( self.cache[:seq_len] if input_pos is None else self.cache[input_pos] ) - # Alternative: Use reference-style rotation for comparison - cos = rope_cache[..., 0] # [seq_len, dim//2] - sin = rope_cache[..., 1] # [seq_len, dim//2] - - # Expand for broadcasting: [1, seq_len, 1, dim//2] - cos = cos.unsqueeze(0).unsqueeze(2) - sin = sin.unsqueeze(0).unsqueeze(2) - - # Split input into two halves - x1 = x[..., : x.shape[-1] // 2] # [b, s, n_h, dim//2] - x2 = x[..., x.shape[-1] // 2:] # [b, s, n_h, dim//2] - - # Apply rotation - rotated_x1 = x1 * cos - x2 * sin - rotated_x2 = x1 * sin + x2 * cos - - return torch.cat([rotated_x1, rotated_x2], dim=-1) - - -# Reference implementation for comparison -class ReferenceYarnRoPE(nn.Module): - """Reference implementation based on DeepSeek's YaRN RoPE""" - - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - scaling_factor=1.0, - original_max_position_embeddings=4096, - beta_fast=32, - beta_slow=1, - mscale=1, - mscale_all_dim=0, - ): - super().__init__() - self.dim = dim - self.base = base - self.max_position_embeddings = max_position_embeddings - self.scaling_factor = scaling_factor - self.original_max_position_embeddings = original_max_position_embeddings - self.beta_fast = beta_fast - self.beta_slow = beta_slow - self.mscale = mscale - self.mscale_all_dim = mscale_all_dim - self._set_cos_sin_cache(max_position_embeddings, torch.device("cpu"), torch.float32) - - def yarn_find_correction_dim(self, num_rotations, dim, base, max_position_embeddings): - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( - 2 * math.log(base) - ) - - def yarn_find_correction_range(self, low_rot, high_rot, dim, base, max_position_embeddings): - low = math.floor( - self.yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) - ) - high = math.ceil( - self.yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) - ) - return max(low, 0), min(high, dim - 1) - - def yarn_get_mscale(self, scale=1, mscale=1): - if scale <= 1: - return 1.0 - return 0.1 * mscale * math.log(scale) + 1.0 - - def yarn_linear_ramp_mask(self, min, max, dim): - if min == max: - max += 0.001 - linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - dim = self.dim - - freq_extra = 1.0 / ( - self.base - ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) - ) - freq_inter = 1.0 / ( - self.scaling_factor - * self.base - ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) - ) - - low, high = self.yarn_find_correction_range( - self.beta_fast, - self.beta_slow, - dim, - self.base, - self.original_max_position_embeddings, - ) - inv_freq_mask = 1.0 - self.yarn_linear_ramp_mask(low, high, dim // 2).to( - device=device, dtype=torch.float32 - ) - inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask - self.register_buffer("inv_freq", inv_freq, persistent=False) - - t = torch.arange(seq_len, device=device, dtype=torch.float32) - - freqs = torch.outer(t, inv_freq) - - _mscale = float( - self.yarn_get_mscale(self.scaling_factor, 
self.mscale) - / self.yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) - ) - - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer( - "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False - ) - self.register_buffer( - "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False - ) - - def forward(self, x): - seq_len = x.shape[1] - cos = self.cos_cached[:seq_len].unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim] - sin = self.sin_cached[:seq_len].unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim] - - # Split the last dimension - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] - - # Since cos and sin are duplicated (freqs concatenated with itself), - # we only need the first half for each rotation pair - cos_half = cos[..., : cos.shape[-1] // 2] # [1, seq_len, 1, dim//2] - sin_half = sin[..., : sin.shape[-1] // 2] # [1, seq_len, 1, dim//2] - - # Apply rotation - rotated_x1 = x1 * cos_half - x2 * sin_half - rotated_x2 = x1 * sin_half + x2 * cos_half - - return torch.cat([rotated_x1, rotated_x2], dim=-1) - - -def print_table(title, data): - """Print results in a nice table format""" - print(f"\n{title}") - print("=" * len(title)) - - # Find maximum width for each column - headers = list(data[0].keys()) - widths = [max(len(str(row[col])) for row in data + [dict(zip(headers, headers))]) for col in headers] - - # Print header - header_row = " | ".join(f"{headers[i]:<{widths[i]}}" for i in range(len(headers))) - print(header_row) - print("-" * len(header_row)) + # reshape input; the last dimension is used for computing the output. + # Cast to float to match the reference implementation + # tensor has shape [b, s, n_h, h_d // 2, 2] + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) - # Print data rows - for row in data: - data_row = " | ".join(f"{str(row[col]):<{widths[i]}}" for i, col in enumerate(headers)) - print(data_row) + # reshape the cache for broadcasting + # tensor has shape [b, s, 1, h_d // 2, 2] if packed samples, + # otherwise has shape [1, s, 1, h_d // 2, 2] + rope_cache = rope_cache.view(-1, xshaped.size(1), 1, xshaped.size(3), 2) - -if __name__ == "__main__": - print("Testing YaRN RoPE implementation...") - - # Test parameters - batch_size = 2 - seq_len = 512 - num_heads = 4 - head_dim = 64 - - # Create test input - torch.manual_seed(42) - x = torch.randn(batch_size, seq_len, num_heads, head_dim) - print(f"Input shape: {x.shape}") - - # Test configurations - test_configs = [ - {"scale": 1.0, "beta_fast": 32, "beta_slow": 1}, - {"scale": 2.0, "beta_fast": 32, "beta_slow": 1}, - {"scale": 4.0, "beta_fast": 32, "beta_slow": 1}, - {"scale": 2.0, "beta_fast": 16, "beta_slow": 2}, - ] - - results = [] - - for config in test_configs: - # Create our implementation - our_yarn = DeepSeekV3YarnRotaryEmbeddings( - dim=head_dim, - max_seq_len=1024, - scaling_factor=config["scale"], - original_max_seq_len=512, - beta_fast=config["beta_fast"], - beta_slow=config["beta_slow"], - mscale=1, - mscale_all_dim=0 + # tensor has shape [b, s, n_h, h_d // 2, 2] + x_out = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] + - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, ) - # Create reference implementation - ref_yarn = ReferenceYarnRoPE( - dim=head_dim, - max_position_embeddings=1024, - scaling_factor=config["scale"], - original_max_position_embeddings=512, - beta_fast=config["beta_fast"], - beta_slow=config["beta_slow"], - mscale=1, - mscale_all_dim=0 - ) - - # Run 
forward passes - our_output = our_yarn(x) - ref_output = ref_yarn(x) - - # Calculate metrics - freq_match = torch.allclose(our_yarn.theta, ref_yarn.inv_freq, atol=1e-6) - cos_match = torch.allclose(our_yarn.cache[:seq_len, :, 0], - ref_yarn.cos_cached[:seq_len, :head_dim // 2], atol=1e-6) - sin_match = torch.allclose(our_yarn.cache[:seq_len, :, 1], - ref_yarn.sin_cached[:seq_len, :head_dim // 2], atol=1e-6) - output_match = torch.allclose(our_output, ref_output, atol=1e-5) - - max_diff = (our_output - ref_output).abs().max().item() - mean_diff = (our_output - ref_output).abs().mean().item() - - results.append({ - "Scale": f"{config['scale']}x", - "Beta Range": f"[{config['beta_slow']}, {config['beta_fast']}]", - "Freq Match": "✓" if freq_match else "✗", - "Cos Match": "✓" if cos_match else "✗", - "Sin Match": "✓" if sin_match else "✗", - "Output Match": "✓" if output_match else "✗", - "Max Diff": f"{max_diff:.2e}", - "Mean Diff": f"{mean_diff:.2e}" - }) - - print_table("YaRN RoPE Comparison Results", results) - - # Detailed analysis for 2x scaling - print("\n" + "=" * 50) - print("DETAILED ANALYSIS FOR 2x SCALING") - print("=" * 50) - - our_yarn = DeepSeekV3YarnRotaryEmbeddings( - dim=head_dim, max_seq_len=1024, scaling_factor=2.0, - original_max_seq_len=512, beta_fast=32, beta_slow=1 - ) - ref_yarn = ReferenceYarnRoPE( - dim=head_dim, max_position_embeddings=1024, scaling_factor=2.0, - original_max_position_embeddings=512, beta_fast=32, beta_slow=1 - ) - - our_output = our_yarn(x) - ref_output = ref_yarn(x) - - analysis_data = [ - {"Component": "Theta/InvFreq", "Shape": str(our_yarn.theta.shape), "Match": "✓" if torch.allclose( - our_yarn.theta, ref_yarn.inv_freq, atol=1e-6) else "✗"}, - {"Component": "Cos Cache", "Shape": f"{our_yarn.cache.shape[0]}x{our_yarn.cache.shape[1]}", "Match": "✓" if torch.allclose( - our_yarn.cache[:seq_len, :, 0], ref_yarn.cos_cached[:seq_len, :head_dim // 2], atol=1e-6) else "✗"}, - {"Component": "Sin Cache", "Shape": f"{our_yarn.cache.shape[0]}x{our_yarn.cache.shape[1]}", "Match": "✓" if torch.allclose( - our_yarn.cache[:seq_len, :, 1], ref_yarn.sin_cached[:seq_len, :head_dim // 2], atol=1e-6) else "✗"}, - {"Component": "Final Output", "Shape": str(our_output.shape), "Match": "✓" if torch.allclose( - our_output, ref_output, atol=1e-5) else "✗"}, - ] - - print_table("Component Analysis", analysis_data) - - # Statistics comparison - stats_data = [ - {"Metric": "Mean", "Our Impl": f"{our_output.mean():.6f}", "Reference": f"{ref_output.mean():.6f}", - "Diff": f"{abs(our_output.mean() - ref_output.mean()):.2e}"}, - {"Metric": "Std", "Our Impl": f"{our_output.std():.6f}", "Reference": f"{ref_output.std():.6f}", - "Diff": f"{abs(our_output.std() - ref_output.std()):.2e}"}, - {"Metric": "Min", "Our Impl": f"{our_output.min():.6f}", "Reference": f"{ref_output.min():.6f}", - "Diff": f"{abs(our_output.min() - ref_output.min()):.2e}"}, - {"Metric": "Max", "Our Impl": f"{our_output.max():.6f}", "Reference": f"{ref_output.max():.6f}", - "Diff": f"{abs(our_output.max() - ref_output.max()):.2e}"}, - ] - - print_table("Output Statistics", stats_data) - - print(f"\nOverall Assessment:") - print(f"• Frequencies match: Perfect ✓") - print(f"• Cached values match: Perfect ✓") - print( - f"• Final outputs match: {'Perfect ✓' if torch.allclose(our_output, ref_output, atol=1e-5) else 'Close but not exact'}") - print(f"• Max difference: {(our_output - ref_output).abs().max():.2e}") - - if torch.allclose(our_output, ref_output, atol=1e-4): - print(f"• Assessment: ✅ EXCELLENT - 
Differences are within numerical precision") - elif torch.allclose(our_output, ref_output, atol=1e-3): - print(f"• Assessment: ✅ GOOD - Small differences, likely implementation variants") - else: - print(f"• Assessment: ⚠️ NEEDS INVESTIGATION - Significant differences detected") + # tensor has shape [b, s, n_h, h_d] + x_out = x_out.flatten(3) + return x_out.type_as(x) From 3672173034d4cd7a73fa73a5d545d568b6572bb6 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 11 Jun 2025 12:43:07 -0700 Subject: [PATCH 16/25] reverting --- torchtune/models/llama4/_position_embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtune/models/llama4/_position_embeddings.py b/torchtune/models/llama4/_position_embeddings.py index 5f8b88a1d1..3adfc32738 100644 --- a/torchtune/models/llama4/_position_embeddings.py +++ b/torchtune/models/llama4/_position_embeddings.py @@ -21,7 +21,7 @@ class Llama4ScaledRoPE(nn.Module): In this implementation we cache the embeddings for each position upto ``max_seq_len`` by computing this during init. - Note that this class is identical to :class:`~torchtune.models.llama3_1.Llama3ScaledRoPE`, but with different default values + Note that this class is identical to :class:`~torchtune.models.llama4.Llama3ScaledRoPE`, but with different default values for scaling factors, as set for Llama4 Scout. See the meta-llama reference code here: https://github.com/meta-llama/llama-models/blob/28fa4e3b287e84f6a6a92aab3c931f7479c827c1/models/llama4/args.py#L100-L107 From 60555539eaea3cdd9f7c6e3fdca461a5ce42ac56 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 11 Jun 2025 15:13:11 -0700 Subject: [PATCH 17/25] adding 'tiny' model --- .../models/deepseek_v3/_model_builders.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/torchtune/models/deepseek_v3/_model_builders.py b/torchtune/models/deepseek_v3/_model_builders.py index e69de29bb2..2ec312d259 100644 --- a/torchtune/models/deepseek_v3/_model_builders.py +++ b/torchtune/models/deepseek_v3/_model_builders.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtune.models.deepseek_v3._component_builders import deepseek_v3 + + +def deepseek_v3_6B_64e(): + """ + Builder for a DeepSeek V3 6.1B model with 64 experts. 
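+
+    A minimal usage sketch (illustrative only; it assumes this builder and its
+    defaults are importable from ``torchtune.models.deepseek_v3`` and that
+    enough memory is available to instantiate the full model)::
+
+        >>> import torch
+        >>> model = deepseek_v3_6B_64e()
+        >>> tokens = torch.randint(0, 129280, (1, 8))  # [batch_size, seq_len]
+        >>> logits = model(tokens)  # [batch_size, seq_len, vocab_size=129280]
+
+    A matching reference checkpoint is linked below.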
+ https://huggingface.co/smohammadi/deepseek-v3-micro + """ + return deepseek_v3( + vocab_size=129280, + num_layers=16, + num_heads=32, + embed_dim=2048, + max_seq_len=32768, + mlp_hidden_dim=5632, + rope_base=10000, + norm_eps=1e-6, + moe_every_n_layers=1, + first_moe_layer=3, + moe_hidden_dim=1024, + num_experts=64, + num_shared_experts=1, + experts_per_token=8, + num_groups=8, + topk_groups=4, + norm_topk_prob=True, + routed_scaling_factor=2.5, + q_lora_rank=256, + kv_lora_rank=128, + qk_rope_head_dim=64, + qk_nope_head_dim=128, + v_head_dim=128, + rope_scaling_factor=40.0, + original_max_seq_len=4096, + beta_fast=32.0, + beta_slow=1.0, + mscale=1.0, + mscale_all_dim=1.0, + ) From 8d292b810d0e2fa9cb54b67792e950be7f96252a Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 11 Jun 2025 15:18:26 -0700 Subject: [PATCH 18/25] fixing mscale --- torchtune/models/deepseek_v3/_attention.py | 4 ++-- .../models/deepseek_v3/_component_builders.py | 7 +++++-- .../models/deepseek_v3/_position_embeddings.py | 14 ++++++++------ 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/torchtune/models/deepseek_v3/_attention.py b/torchtune/models/deepseek_v3/_attention.py index 8760c69db8..77d0630cf5 100644 --- a/torchtune/models/deepseek_v3/_attention.py +++ b/torchtune/models/deepseek_v3/_attention.py @@ -43,7 +43,7 @@ def __init__(self, self.kv_proj = kv_proj self.output_proj = output_proj self.pos_embeddings = pos_embeddings - self.softmax_scale = self.q_head_dim ** (-0.5) + self.softmax_scale = self.q_head_dim ** (-0.5) * self.pos_embeddings.mscale * self.pos_embeddings.mscale self.cache_enabled = False self._attention_call = _sdpa_or_flex_attention() @@ -91,7 +91,7 @@ def forward( is_causal=mask is None, scale=self.softmax_scale, ) - + # reshape the output to be the same shape as the input output = output.transpose(1, 2).contiguous().view(b, s_x, -1) diff --git a/torchtune/models/deepseek_v3/_component_builders.py b/torchtune/models/deepseek_v3/_component_builders.py index a563048801..846c1514be 100644 --- a/torchtune/models/deepseek_v3/_component_builders.py +++ b/torchtune/models/deepseek_v3/_component_builders.py @@ -29,6 +29,10 @@ def deepseek_v3( num_heads: int, max_seq_len: int, rope_base: int = 10_000, + rope_scaling_factor: Optional[float] = None, + original_max_seq_len: Optional[int] = None, + beta_fast: Optional[float] = None, + beta_slow: Optional[float] = None, q_lora_rank: Optional[int] = None, qk_rope_head_dim: Optional[int] = None, qk_nope_head_dim: Optional[int] = None, @@ -46,6 +50,7 @@ def deepseek_v3( mlp_hidden_dim: Optional[int] = None, moe_hidden_dim: Optional[int] = None, norm_eps: float = 1e-5, + ): # if use_yarn: # raise NotImplementedError("Yarn is not supported yet") @@ -57,8 +62,6 @@ def deepseek_v3( original_max_seq_len=original_max_seq_len, beta_fast=beta_fast, beta_slow=beta_slow, - mscale=mscale, - mscale_all_dim=mscale_all_dim, ) # def rope(x, input_pos=None): # return x diff --git a/torchtune/models/deepseek_v3/_position_embeddings.py b/torchtune/models/deepseek_v3/_position_embeddings.py index 821c4e418b..c1b8cdcaed 100644 --- a/torchtune/models/deepseek_v3/_position_embeddings.py +++ b/torchtune/models/deepseek_v3/_position_embeddings.py @@ -44,8 +44,6 @@ def __init__( original_max_seq_len: int = 4096, beta_fast: float = 32.0, beta_slow: float = 1.0, - mscale: float = 1.0, - mscale_all_dim: float = 0.0, ) -> None: super().__init__() self.dim = dim @@ -55,8 +53,7 @@ def __init__( self.original_max_seq_len = original_max_seq_len self.beta_fast = 
beta_fast self.beta_slow = beta_slow - self.mscale = mscale - self.mscale_all_dim = mscale_all_dim + self.mscale = None self.rope_init() def _yarn_find_correction_dim( @@ -120,6 +117,11 @@ def rope_init(self): self.register_buffer("theta", theta, persistent=False) self.build_rope_cache(self.max_seq_len) + @property + def mscale(self): + return self._yarn_get_mscale(self.scaling_factor, self.mscale) + + def build_rope_cache(self, max_seq_len: int = 4096) -> None: # Create position indexes `[0, 1, ..., max_seq_len - 1]` seq_idx = torch.arange( @@ -131,14 +133,14 @@ def build_rope_cache(self, max_seq_len: int = 4096) -> None: idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float() # Calculate magnitude scaling - mscale = float( + self.mscale = float( self._yarn_get_mscale(self.scaling_factor, self.mscale) / self._yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) ) # cache includes both the cos and sin components and so the output shape is # [max_seq_len, dim // 2, 2] - cache = torch.stack([idx_theta.cos() * mscale, idx_theta.sin() * mscale], dim=-1) + cache = torch.stack([idx_theta.cos() * self.mscale, idx_theta.sin() * self.mscale], dim=-1) self.register_buffer("cache", cache, persistent=False) def forward( From f8e58800ab1c1c958d32608d63ae0fdcb788996c Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 11 Jun 2025 15:27:59 -0700 Subject: [PATCH 19/25] adding weight conversion --- torchtune/models/deepseek_v3/_component_builders.py | 4 ---- torchtune/models/deepseek_v3/_convert_weights.py | 11 +++++++++++ torchtune/models/deepseek_v3/_model_builders.py | 2 -- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/torchtune/models/deepseek_v3/_component_builders.py b/torchtune/models/deepseek_v3/_component_builders.py index 846c1514be..65e7d00e58 100644 --- a/torchtune/models/deepseek_v3/_component_builders.py +++ b/torchtune/models/deepseek_v3/_component_builders.py @@ -52,8 +52,6 @@ def deepseek_v3( norm_eps: float = 1e-5, ): - # if use_yarn: - # raise NotImplementedError("Yarn is not supported yet") rope = DeepSeekV3YarnRotaryEmbeddings( dim=qk_rope_head_dim, max_seq_len=max_seq_len, @@ -63,8 +61,6 @@ def deepseek_v3( beta_fast=beta_fast, beta_slow=beta_slow, ) - # def rope(x, input_pos=None): - # return x layers = [] for i in range(num_layers): q_head_dim = qk_rope_head_dim + qk_nope_head_dim diff --git a/torchtune/models/deepseek_v3/_convert_weights.py b/torchtune/models/deepseek_v3/_convert_weights.py index 773472627b..d4385e89ed 100644 --- a/torchtune/models/deepseek_v3/_convert_weights.py +++ b/torchtune/models/deepseek_v3/_convert_weights.py @@ -75,3 +75,14 @@ def deepseek_v3_hf_to_tune(state_dict: dict[str, torch.Tensor]) -> dict[str, tor new_key = get_mapped_key(key, _FROM_HF) converted_state_dict[new_key] = value return converted_state_dict + + +def deepseek_v3_tune_to_hf(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + converted_state_dict = {} + inverted_mapping_dict = {v: k for k, v in _FROM_HF.items() + } + for key, value in state_dict.items(): + new_key = get_mapped_key(key, inverted_mapping_dict) + converted_state_dict[new_key] = value + + return converted_state_dict diff --git a/torchtune/models/deepseek_v3/_model_builders.py b/torchtune/models/deepseek_v3/_model_builders.py index 2ec312d259..d35c4d03a4 100644 --- a/torchtune/models/deepseek_v3/_model_builders.py +++ b/torchtune/models/deepseek_v3/_model_builders.py @@ -40,6 +40,4 @@ def deepseek_v3_6B_64e(): original_max_seq_len=4096, beta_fast=32.0, beta_slow=1.0, - 
mscale=1.0, - mscale_all_dim=1.0, ) From 41e4359706c201a4d26c216ede6b983a7601874e Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 11 Jun 2025 18:46:31 -0700 Subject: [PATCH 20/25] fixes from testing dummy model --- torchtune/models/deepseek_v3/_attention.py | 4 ++- .../models/deepseek_v3/_component_builders.py | 4 +++ .../models/deepseek_v3/_model_builders.py | 2 ++ .../deepseek_v3/_position_embeddings.py | 34 +++++++++---------- torchtune/models/deepseek_v3/_tokenizer.py | 0 5 files changed, 25 insertions(+), 19 deletions(-) create mode 100644 torchtune/models/deepseek_v3/_tokenizer.py diff --git a/torchtune/models/deepseek_v3/_attention.py b/torchtune/models/deepseek_v3/_attention.py index 77d0630cf5..10db089eb0 100644 --- a/torchtune/models/deepseek_v3/_attention.py +++ b/torchtune/models/deepseek_v3/_attention.py @@ -43,7 +43,9 @@ def __init__(self, self.kv_proj = kv_proj self.output_proj = output_proj self.pos_embeddings = pos_embeddings - self.softmax_scale = self.q_head_dim ** (-0.5) * self.pos_embeddings.mscale * self.pos_embeddings.mscale + self.softmax_scale = self.q_head_dim ** (-0.5) + mscale = self.pos_embeddings.get_mscale(self.pos_embeddings.scaling_factor, self.pos_embeddings.mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale self.cache_enabled = False self._attention_call = _sdpa_or_flex_attention() diff --git a/torchtune/models/deepseek_v3/_component_builders.py b/torchtune/models/deepseek_v3/_component_builders.py index 65e7d00e58..5df1138e6f 100644 --- a/torchtune/models/deepseek_v3/_component_builders.py +++ b/torchtune/models/deepseek_v3/_component_builders.py @@ -33,6 +33,8 @@ def deepseek_v3( original_max_seq_len: Optional[int] = None, beta_fast: Optional[float] = None, beta_slow: Optional[float] = None, + mscale: Optional[float] = None, + mscale_all_dim: Optional[float] = None, q_lora_rank: Optional[int] = None, qk_rope_head_dim: Optional[int] = None, qk_nope_head_dim: Optional[int] = None, @@ -60,6 +62,8 @@ def deepseek_v3( original_max_seq_len=original_max_seq_len, beta_fast=beta_fast, beta_slow=beta_slow, + mscale=mscale, + mscale_all_dim=mscale_all_dim, ) layers = [] for i in range(num_layers): diff --git a/torchtune/models/deepseek_v3/_model_builders.py b/torchtune/models/deepseek_v3/_model_builders.py index d35c4d03a4..2ec312d259 100644 --- a/torchtune/models/deepseek_v3/_model_builders.py +++ b/torchtune/models/deepseek_v3/_model_builders.py @@ -40,4 +40,6 @@ def deepseek_v3_6B_64e(): original_max_seq_len=4096, beta_fast=32.0, beta_slow=1.0, + mscale=1.0, + mscale_all_dim=1.0, ) diff --git a/torchtune/models/deepseek_v3/_position_embeddings.py b/torchtune/models/deepseek_v3/_position_embeddings.py index c1b8cdcaed..ba0925c73a 100644 --- a/torchtune/models/deepseek_v3/_position_embeddings.py +++ b/torchtune/models/deepseek_v3/_position_embeddings.py @@ -44,6 +44,8 @@ def __init__( original_max_seq_len: int = 4096, beta_fast: float = 32.0, beta_slow: float = 1.0, + mscale: float = 1.0, + mscale_all_dim: float = 1.0, ) -> None: super().__init__() self.dim = dim @@ -53,33 +55,34 @@ def __init__( self.original_max_seq_len = original_max_seq_len self.beta_fast = beta_fast self.beta_slow = beta_slow - self.mscale = None + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim self.rope_init() - def _yarn_find_correction_dim( + def _find_correction_dim( self, num_rotations: float, dim: int, base: int, max_position_embeddings: int ) -> float: return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( 
2 * math.log(base) ) - def _yarn_find_correction_range( + def _find_correction_range( self, low_rot: float, high_rot: float, dim: int, base: int, max_position_embeddings: int ) -> tuple[int, int]: low = math.floor( - self._yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + self._find_correction_dim(low_rot, dim, base, max_position_embeddings) ) high = math.ceil( - self._yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + self._find_correction_dim(high_rot, dim, base, max_position_embeddings) ) return max(low, 0), min(high, dim - 1) - def _yarn_get_mscale(self, scale: float = 1.0, mscale: float = 1.0) -> float: + def get_mscale(self, scale: float = 1.0, mscale: float = 1.0) -> float: if scale <= 1: return 1.0 return 0.1 * mscale * math.log(scale) + 1.0 - def _yarn_linear_ramp_mask(self, min_val: int, max_val: int, dim: int) -> torch.Tensor: + def _get_linear_ramp_mask(self, min_val: int, max_val: int, dim: int) -> torch.Tensor: if min_val == max_val: max_val += 0.001 @@ -100,7 +103,7 @@ def rope_init(self): ) # Find correction range for frequency interpolation - low, high = self._yarn_find_correction_range( + low, high = self._find_correction_range( self.beta_fast, self.beta_slow, self.dim, @@ -109,7 +112,7 @@ def rope_init(self): ) # Create interpolation mask - inv_freq_mask = 1.0 - self._yarn_linear_ramp_mask(low, high, self.dim // 2) + inv_freq_mask = 1.0 - self._get_linear_ramp_mask(low, high, self.dim // 2) # Interpolate between scaled and unscaled frequencies theta = freq_interp * (1 - inv_freq_mask) + freq_base * inv_freq_mask @@ -117,11 +120,6 @@ def rope_init(self): self.register_buffer("theta", theta, persistent=False) self.build_rope_cache(self.max_seq_len) - @property - def mscale(self): - return self._yarn_get_mscale(self.scaling_factor, self.mscale) - - def build_rope_cache(self, max_seq_len: int = 4096) -> None: # Create position indexes `[0, 1, ..., max_seq_len - 1]` seq_idx = torch.arange( @@ -133,14 +131,14 @@ def build_rope_cache(self, max_seq_len: int = 4096) -> None: idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float() # Calculate magnitude scaling - self.mscale = float( - self._yarn_get_mscale(self.scaling_factor, self.mscale) - / self._yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + mscale = float( + self.get_mscale(self.scaling_factor, self.mscale) + / self.get_mscale(self.scaling_factor, self.mscale_all_dim) ) # cache includes both the cos and sin components and so the output shape is # [max_seq_len, dim // 2, 2] - cache = torch.stack([idx_theta.cos() * self.mscale, idx_theta.sin() * self.mscale], dim=-1) + cache = torch.stack([idx_theta.cos() * mscale, idx_theta.sin() * mscale], dim=-1) self.register_buffer("cache", cache, persistent=False) def forward( diff --git a/torchtune/models/deepseek_v3/_tokenizer.py b/torchtune/models/deepseek_v3/_tokenizer.py new file mode 100644 index 0000000000..e69de29bb2 From 3c8c1b1adbc61c1a8f4dbc63e8ffd8be9f009e55 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 11 Jun 2025 19:06:03 -0700 Subject: [PATCH 21/25] wip tokenizer --- .../6B_64e_full_single_device.yaml | 112 ++++++++++++++++++ torchtune/models/deepseek_v3/_tokenizer.py | 20 ++++ 2 files changed, 132 insertions(+) create mode 100644 recipes/configs/deepseek_v3/6B_64e_full_single_device.yaml diff --git a/recipes/configs/deepseek_v3/6B_64e_full_single_device.yaml b/recipes/configs/deepseek_v3/6B_64e_full_single_device.yaml new file mode 100644 index 0000000000..0d7a6e6891 --- /dev/null +++ 
b/recipes/configs/deepseek_v3/6B_64e_full_single_device.yaml @@ -0,0 +1,112 @@ +# Config for single device full finetuning in full_finetune_single_device.py +# using a Qwen2.5 7B +# +# This config assumes that you've run the following command before launching +# this run: +# tune download Qwen/Qwen2.5-7B-Instruct --output-dir /tmp/Qwen2.5-7B-Instruct +# +# The default config uses an optimizer from bitsandbytes. If you do not have it installed, +# you can install it with +# pip install bitsandbytes +# +# To launch on a single device, run the following command from root: +# tune run full_finetune_single_device --config qwen2_5/7B_full_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run full_finetune_single_device --config qwen2_5/7B_full_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +output_dir: /tmp/torchtune/qwen2_5_7B/full_single_device # /tmp may be deleted by your system. Change it to your preference. + +# Tokenizer +tokenizer: + _component_: torchtune.models.dee.qwen2_5_tokenizer + path: /tmp/Qwen2.5-7B-Instruct/vocab.json + merges_file: /tmp/Qwen2.5-7B-Instruct/merges.txt + max_seq_len: null + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.qwen2_5.qwen2_5_7b_instruct + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Qwen2.5-7B-Instruct + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors, + ] + recipe_checkpoint: null + output_dir: ${output_dir} + model_type: QWEN2 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 1 +optimizer: + _component_: torch.optim.AdamW + lr: 5e-6 +optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1 +loss: + _component_: torchtune.modules.loss.LinearCrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null +compile: False # torch.compile the model + loss, True increases speed + decreases memory + +# Training environment +device: cuda + +# Memory management +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir}/logs +log_every_n_steps: 1 +log_peak_memory_stats: False +log_level: INFO # DEBUG, WARN, etc. 
+ + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/torchtune/models/deepseek_v3/_tokenizer.py b/torchtune/models/deepseek_v3/_tokenizer.py index e69de29bb2..2b694e88e2 100644 --- a/torchtune/models/deepseek_v3/_tokenizer.py +++ b/torchtune/models/deepseek_v3/_tokenizer.py @@ -0,0 +1,20 @@ +from typing import Optional +from torchtune.modules.transforms.tokenizers import HuggingFaceBaseTokenizer, ModelTokenizer +from torchtune.modules.transforms import Transform + + +class DeepSeekV3Tokenizer(ModelTokenizer, Transform): + + def __init__(self, + path: str, + config_path: str, + max_seq_len: Optional[int] = None, + ) -> None: + + self.hf_tokenizer = HuggingFaceBaseTokenizer(path, tokenizer_config_json_path=config_path) + self.max_seq_len = max_seq_len + + @property + def vocab_size(self) -> int: + return self.hf_tokenizer.get_vocab_size() + From 2f1ee405568672925bcb065a6a71cce2a9c09f1e Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 11 Jun 2025 22:39:58 -0700 Subject: [PATCH 22/25] adding 'tiny' model --- torchtune/models/deepseek_v3/_tokenizer.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/torchtune/models/deepseek_v3/_tokenizer.py b/torchtune/models/deepseek_v3/_tokenizer.py index 2b694e88e2..cd5e0b0bf8 100644 --- a/torchtune/models/deepseek_v3/_tokenizer.py +++ b/torchtune/models/deepseek_v3/_tokenizer.py @@ -18,3 +18,20 @@ def __init__(self, def vocab_size(self) -> int: return self.hf_tokenizer.get_vocab_size() + +if __name__ == "__main__": + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained("smohammadi/deepseek-v3-micro") + text = "Hello, world!" + tokens = tokenizer.encode(text, add_special_tokens=True) + print(tokens) + print(tokenizer.decode(tokens)) + + tt_tokenizer = DeepSeekV3Tokenizer( + path="/Users/salmanmohammadi/projects/torchtune/target/deepseek/deepseek-v3-micro/tokenizer.json", + config_path="/Users/salmanmohammadi/projects/torchtune/target/deepseek/deepseek-v3-micro/tokenizer_config.json", + max_seq_len=1024 + ) + tt_tokens = tt_tokenizer.encode(text, add_bos=True, add_eos=True) + print(tt_tokens) + print(tt_tokenizer.decode(tt_tokens)) \ No newline at end of file From fe265dca40bbf2ad842969b4347cf0a33a454968 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 11 Jun 2025 23:19:13 -0700 Subject: [PATCH 23/25] training running --- .../6B_64e_full_single_device.yaml | 39 ++++++++++--------- torchtune/models/deepseek_v3/_tokenizer.py | 28 ++++++++++--- 2 files changed, 44 insertions(+), 23 deletions(-) diff --git a/recipes/configs/deepseek_v3/6B_64e_full_single_device.yaml b/recipes/configs/deepseek_v3/6B_64e_full_single_device.yaml index 0d7a6e6891..8c1113c1f2 100644 --- a/recipes/configs/deepseek_v3/6B_64e_full_single_device.yaml +++ b/recipes/configs/deepseek_v3/6B_64e_full_single_device.yaml @@ -19,42 +19,45 @@ # # This config works only for training on single device. 
-output_dir: /tmp/torchtune/qwen2_5_7B/full_single_device # /tmp may be deleted by your system. Change it to your preference. +output_dir: /tmp/dsv3 # /tmp may be deleted by your system. Change it to your preference. # Tokenizer tokenizer: - _component_: torchtune.models.dee.qwen2_5_tokenizer - path: /tmp/Qwen2.5-7B-Instruct/vocab.json - merges_file: /tmp/Qwen2.5-7B-Instruct/merges.txt - max_seq_len: null + _component_: torchtune.models.deepseek_v3._tokenizer.DeepSeekV3Tokenizer + path: /Users/salmanmohammadi/projects/torchtune/target/scripts/deepseek/deepseek-v3-micro/tokenizer.json + config_path: /Users/salmanmohammadi/projects/torchtune/target/scripts/deepseek/deepseek-v3-micro/tokenizer_config.json + max_seq_len: 1024 # Dataset dataset: - _component_: torchtune.datasets.alpaca_cleaned_dataset + _component_: torchtune.datasets.text_completion_dataset + source: openai/gsm8k + column: question + name: main + split: train packed: False # True increases speed seed: null shuffle: True # Model Arguments model: - _component_: torchtune.models.qwen2_5.qwen2_5_7b_instruct + _component_: torchtune.models.deepseek_v3._model_builders.deepseek_v3_6B_64e checkpointer: _component_: torchtune.training.FullModelHFCheckpointer - checkpoint_dir: /tmp/Qwen2.5-7B-Instruct + checkpoint_dir: /Users/salmanmohammadi/projects/torchtune/target/scripts/deepseek/deepseek-v3-micro checkpoint_files: [ - model-00001-of-00004.safetensors, - model-00002-of-00004.safetensors, - model-00003-of-00004.safetensors, - model-00004-of-00004.safetensors, + model-00001-of-00003.safetensors, + model-00002-of-00003.safetensors, + model-00003-of-00003.safetensors, ] recipe_checkpoint: null output_dir: ${output_dir} - model_type: QWEN2 + model_type: DEEPSEEK_V3 resume_from_checkpoint: False # Fine-tuning arguments -batch_size: 2 +batch_size: 1 epochs: 1 optimizer: _component_: torch.optim.AdamW @@ -68,21 +71,21 @@ clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training environment -device: cuda +device: mps # Memory management enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Reduced precision -dtype: bf16 +dtype: fp32 # Logging metric_logger: - _component_: torchtune.training.metric_logging.DiskLogger + _component_: torchtune.training.metric_logging.StdoutLogger log_dir: ${output_dir}/logs log_every_n_steps: 1 -log_peak_memory_stats: False +log_peak_memory_stats: True log_level: INFO # DEBUG, WARN, etc. 
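As a quick sanity check of the attention scaling this config ends up exercising: with the YaRN embeddings, the softmax scale is q_head_dim ** -0.5 multiplied by the squared magnitude correction get_mscale(scaling_factor, mscale_all_dim) introduced in the attention patch earlier in this series. Below is a standalone sketch of that arithmetic; the head dims happen to match the moonlight builder added later in the series, and the context-extension factor is a placeholder rather than a value taken from these patches.

import math

def get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    # Mirrors the helper in _position_embeddings.py: no correction unless the
    # context window is actually being extended (scale > 1).
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

qk_rope_head_dim, qk_nope_head_dim = 64, 128       # illustrative head dims
q_head_dim = qk_rope_head_dim + qk_nope_head_dim   # 192
scaling_factor, mscale_all_dim = 40.0, 1.0         # placeholder YaRN settings

softmax_scale = q_head_dim ** (-0.5)
m = get_mscale(scaling_factor, mscale_all_dim)
softmax_scale = softmax_scale * m * m              # same squared correction as in DeepSeekV3Attention
print(round(softmax_scale, 5))                     # ~0.13523; without the correction it would be ~0.07217
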
diff --git a/torchtune/models/deepseek_v3/_tokenizer.py b/torchtune/models/deepseek_v3/_tokenizer.py index cd5e0b0bf8..768520978c 100644 --- a/torchtune/models/deepseek_v3/_tokenizer.py +++ b/torchtune/models/deepseek_v3/_tokenizer.py @@ -1,7 +1,7 @@ from typing import Optional from torchtune.modules.transforms.tokenizers import HuggingFaceBaseTokenizer, ModelTokenizer from torchtune.modules.transforms import Transform - +from functools import cached_property class DeepSeekV3Tokenizer(ModelTokenizer, Transform): @@ -10,7 +10,6 @@ def __init__(self, config_path: str, max_seq_len: Optional[int] = None, ) -> None: - self.hf_tokenizer = HuggingFaceBaseTokenizer(path, tokenizer_config_json_path=config_path) self.max_seq_len = max_seq_len @@ -18,6 +17,24 @@ def __init__(self, def vocab_size(self) -> int: return self.hf_tokenizer.get_vocab_size() + def encode(self, *args, **kwargs) -> list[int]: + return self.hf_tokenizer.encode(*args, **kwargs) + + def decode(self, *args, **kwargs) -> str: + return self.hf_tokenizer.decode(*args, **kwargs) + + @property + def bos_id(self) -> int: + return self.hf_tokenizer.bos_id + + @property + def eos_id(self) -> int: + return self.hf_tokenizer.eos_id + + @cached_property + def pad_id(self) -> int: + return self.hf_tokenizer.tokenizer.token_to_id(self.hf_tokenizer.config.get("pad_token")) + if __name__ == "__main__": from transformers import AutoTokenizer @@ -28,10 +45,11 @@ def vocab_size(self) -> int: print(tokenizer.decode(tokens)) tt_tokenizer = DeepSeekV3Tokenizer( - path="/Users/salmanmohammadi/projects/torchtune/target/deepseek/deepseek-v3-micro/tokenizer.json", - config_path="/Users/salmanmohammadi/projects/torchtune/target/deepseek/deepseek-v3-micro/tokenizer_config.json", + path="/Users/salmanmohammadi/projects/torchtune/target/scripts/deepseek/deepseek-v3-micro/tokenizer.json", + config_path="/Users/salmanmohammadi/projects/torchtune/target/scripts/deepseek/deepseek-v3-micro/tokenizer_config.json", max_seq_len=1024 ) tt_tokens = tt_tokenizer.encode(text, add_bos=True, add_eos=True) print(tt_tokens) - print(tt_tokenizer.decode(tt_tokens)) \ No newline at end of file + print(tt_tokenizer.decode(tt_tokens)) + import ipdb; ipdb.set_trace() \ No newline at end of file From 7d3c95e111df6730bb92b8ce431f61f2c2b68812 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 11 Jun 2025 23:39:09 -0700 Subject: [PATCH 24/25] WIP configurable rope --- torchtune/models/deepseek_v3/_attention.py | 5 ++-- .../models/deepseek_v3/_component_builders.py | 29 ++++++++++++------- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/torchtune/models/deepseek_v3/_attention.py b/torchtune/models/deepseek_v3/_attention.py index 10db089eb0..d8b717d32f 100644 --- a/torchtune/models/deepseek_v3/_attention.py +++ b/torchtune/models/deepseek_v3/_attention.py @@ -44,8 +44,9 @@ def __init__(self, self.output_proj = output_proj self.pos_embeddings = pos_embeddings self.softmax_scale = self.q_head_dim ** (-0.5) - mscale = self.pos_embeddings.get_mscale(self.pos_embeddings.scaling_factor, self.pos_embeddings.mscale_all_dim) - self.softmax_scale = self.softmax_scale * mscale * mscale + if hasattr(self.pos_embeddings, "get_mscale"): + mscale = self.pos_embeddings.get_mscale(self.pos_embeddings.scaling_factor, self.pos_embeddings.mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale self.cache_enabled = False self._attention_call = _sdpa_or_flex_attention() diff --git a/torchtune/models/deepseek_v3/_component_builders.py 
b/torchtune/models/deepseek_v3/_component_builders.py index 5df1138e6f..bdc1a45e63 100644 --- a/torchtune/models/deepseek_v3/_component_builders.py +++ b/torchtune/models/deepseek_v3/_component_builders.py @@ -54,17 +54,24 @@ def deepseek_v3( norm_eps: float = 1e-5, ): - rope = DeepSeekV3YarnRotaryEmbeddings( - dim=qk_rope_head_dim, - max_seq_len=max_seq_len, - base=rope_base, - scaling_factor=rope_scaling_factor, - original_max_seq_len=original_max_seq_len, - beta_fast=beta_fast, - beta_slow=beta_slow, - mscale=mscale, - mscale_all_dim=mscale_all_dim, - ) + if rope_scaling_factor: + rope = DeepSeekV3YarnRotaryEmbeddings( + dim=qk_rope_head_dim, + max_seq_len=max_seq_len, + base=rope_base, + scaling_factor=rope_scaling_factor, + original_max_seq_len=original_max_seq_len, + beta_fast=beta_fast, + beta_slow=beta_slow, + mscale=mscale, + mscale_all_dim=mscale_all_dim, + ) + else: + rope = RotaryPositionalEmbeddings( + dim=qk_rope_head_dim, + max_seq_len=max_seq_len, + base=rope_base, + ) layers = [] for i in range(num_layers): q_head_dim = qk_rope_head_dim + qk_nope_head_dim From eb510d328dbd29ace1f4628f6d26ab2dc5988c85 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Tue, 24 Jun 2025 14:56:01 +0100 Subject: [PATCH 25/25] updating max seq len and vocab size --- recipes/configs/deepseek_v3/moonlight.yaml | 113 ++++++++++++++++++ .../models/deepseek_v3/_model_builders.py | 31 ++++- 2 files changed, 143 insertions(+), 1 deletion(-) create mode 100644 recipes/configs/deepseek_v3/moonlight.yaml diff --git a/recipes/configs/deepseek_v3/moonlight.yaml b/recipes/configs/deepseek_v3/moonlight.yaml new file mode 100644 index 0000000000..cbbc25a685 --- /dev/null +++ b/recipes/configs/deepseek_v3/moonlight.yaml @@ -0,0 +1,113 @@ +# Config for single device full finetuning in full_finetune_single_device.py +# using a Qwen2.5 7B +# +# This config assumes that you've run the following command before launching +# this run: +# tune download Qwen/Qwen2.5-7B-Instruct --output-dir /tmp/Qwen2.5-7B-Instruct +# +# The default config uses an optimizer from bitsandbytes. If you do not have it installed, +# you can install it with +# pip install bitsandbytes +# +# To launch on a single device, run the following command from root: +# tune run full_finetune_single_device --config qwen2_5/7B_full_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run full_finetune_single_device --config qwen2_5/7B_full_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +output_dir: /tmp/dsv3 # /tmp may be deleted by your system. Change it to your preference. 
+ +# Tokenizer +tokenizer: + _component_: torchtune.models.deepseek_v3._tokenizer.DeepSeekV3Tokenizer + path: /Users/salmanmohammadi/projects/torchtune/target/scripts/deepseek/deepseek-v3-micro/tokenizer.json + config_path: /Users/salmanmohammadi/projects/torchtune/target/scripts/deepseek/deepseek-v3-micro/tokenizer_config.json + max_seq_len: 1024 + +# Dataset +dataset: + _component_: torchtune.datasets.text_completion_dataset + source: openai/gsm8k + column: question + name: main + split: train + packed: False # True increases speed +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.deepseek_v3._model_builders.moonlight_16B_64e + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /Users/salmanmohammadi/projects/torchtune/target/scripts/deepseek/moonshot + checkpoint_files: + filename_format: model-{}-of-{}.safetensors + max_filename: "27" + recipe_checkpoint: null + output_dir: ${output_dir} + model_type: DEEPSEEK_V3 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 1 +epochs: 1 +optimizer: + _component_: torch.optim.AdamW + lr: 5e-6 +optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1 +loss: + _component_: torchtune.modules.loss.LinearCrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null +compile: False # torch.compile the model + loss, True increases speed + decreases memory + +# Training environment +device: mps + +# Memory management +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.StdoutLogger + log_dir: ${output_dir}/logs +log_every_n_steps: 1 +log_peak_memory_stats: True +log_level: INFO # DEBUG, WARN, etc. 
+ + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/torchtune/models/deepseek_v3/_model_builders.py b/torchtune/models/deepseek_v3/_model_builders.py index 2ec312d259..76e410c0bd 100644 --- a/torchtune/models/deepseek_v3/_model_builders.py +++ b/torchtune/models/deepseek_v3/_model_builders.py @@ -17,7 +17,7 @@ def deepseek_v3_6B_64e(): num_layers=16, num_heads=32, embed_dim=2048, - max_seq_len=32768, + max_seq_len=163840, mlp_hidden_dim=5632, rope_base=10000, norm_eps=1e-6, @@ -43,3 +43,32 @@ def deepseek_v3_6B_64e(): mscale=1.0, mscale_all_dim=1.0, ) + + +def moonlight_16B_64e(): + return deepseek_v3( + vocab_size=163840, + num_layers=27, + num_heads=16, + embed_dim=2048, + max_seq_len=8192, + mlp_hidden_dim=11264, + rope_base=50000, + norm_eps=1e-5, + moe_every_n_layers=1, + first_moe_layer=1, + moe_hidden_dim=1408, + num_experts=64, + num_shared_experts=2, + experts_per_token=6, + num_groups=1, + topk_groups=1, + norm_topk_prob=True, + routed_scaling_factor=2.446, + q_lora_rank=None, + kv_lora_rank=512, + qk_rope_head_dim=64, + qk_nope_head_dim=128, + v_head_dim=128, + ) +
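
One consequence of the configurable-RoPE change above is that moonlight_16B_64e, which passes no rope_scaling_factor, should fall through to plain RotaryPositionalEmbeddings rather than the YaRN variant. A small sketch of that selection in isolation follows; the build_rope helper is hypothetical (the component builder does this inline), and the import path for the YaRN class is assumed to be the _position_embeddings module from this series.

from typing import Optional

import torch.nn as nn

from torchtune.modules import RotaryPositionalEmbeddings
from torchtune.models.deepseek_v3._position_embeddings import DeepSeekV3YarnRotaryEmbeddings


def build_rope(
    qk_rope_head_dim: int,
    max_seq_len: int,
    rope_base: int,
    rope_scaling_factor: Optional[float] = None,
    **yarn_kwargs,
) -> nn.Module:
    # Hypothetical standalone version of the branch in deepseek_v3():
    # YaRN embeddings only when a scaling factor is configured, vanilla RoPE otherwise.
    if rope_scaling_factor:
        return DeepSeekV3YarnRotaryEmbeddings(
            dim=qk_rope_head_dim,
            max_seq_len=max_seq_len,
            base=rope_base,
            scaling_factor=rope_scaling_factor,
            **yarn_kwargs,
        )
    return RotaryPositionalEmbeddings(
        dim=qk_rope_head_dim, max_seq_len=max_seq_len, base=rope_base
    )


# Moonlight's settings: no scaling factor, so this resolves to plain RoPE.
rope = build_rope(qk_rope_head_dim=64, max_seq_len=8192, rope_base=50000)
print(type(rope).__name__)  # RotaryPositionalEmbeddings

Note that the attention module guards the magnitude correction with hasattr(pos_embeddings, "get_mscale"), so the plain-RoPE path keeps the vanilla q_head_dim ** -0.5 softmax scale.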