diff --git a/recipes/configs/deepseek_v3/6B_64e_full_single_device.yaml b/recipes/configs/deepseek_v3/6B_64e_full_single_device.yaml
new file mode 100644
index 0000000000..8c1113c1f2
--- /dev/null
+++ b/recipes/configs/deepseek_v3/6B_64e_full_single_device.yaml
@@ -0,0 +1,115 @@
+# Config for single device full finetuning in full_finetune_single_device.py
+# using a DeepSeek V3 6B model with 64 experts.
+#
+# This config assumes that you've run the following command before launching
+# this run:
+#   tune download smohammadi/deepseek-v3-micro --output-dir /tmp/deepseek-v3-micro
+#
+# To launch on a single device, run the following command from root:
+#   tune run full_finetune_single_device --config deepseek_v3/6B_64e_full_single_device
+#
+# You can add specific overrides through the command line. For example
+# to override the checkpointer directory while launching training
+# you can run:
+#   tune run full_finetune_single_device --config deepseek_v3/6B_64e_full_single_device checkpointer.checkpoint_dir=<checkpoint_dir>
+#
+# This config works only for training on single device.
+
+output_dir: /tmp/dsv3 # /tmp may be deleted by your system. Change it to your preference.
+
+# Tokenizer
+tokenizer:
+  _component_: torchtune.models.deepseek_v3._tokenizer.DeepSeekV3Tokenizer
+  path: /Users/salmanmohammadi/projects/torchtune/target/scripts/deepseek/deepseek-v3-micro/tokenizer.json
+  config_path: /Users/salmanmohammadi/projects/torchtune/target/scripts/deepseek/deepseek-v3-micro/tokenizer_config.json
+  max_seq_len: 1024
+
+# Dataset
+dataset:
+  _component_: torchtune.datasets.text_completion_dataset
+  source: openai/gsm8k
+  column: question
+  name: main
+  split: train
+  packed: False # True increases speed
+seed: null
+shuffle: True
+
+# Model Arguments
+model:
+  _component_: torchtune.models.deepseek_v3._model_builders.deepseek_v3_6B_64e
+
+checkpointer:
+  _component_: torchtune.training.FullModelHFCheckpointer
+  checkpoint_dir: /Users/salmanmohammadi/projects/torchtune/target/scripts/deepseek/deepseek-v3-micro
+  checkpoint_files: [
+    model-00001-of-00003.safetensors,
+    model-00002-of-00003.safetensors,
+    model-00003-of-00003.safetensors,
+  ]
+  recipe_checkpoint: null
+  output_dir: ${output_dir}
+  model_type: DEEPSEEK_V3
+resume_from_checkpoint: False
+
+# Fine-tuning arguments
+batch_size: 1
+epochs: 1
+optimizer:
+  _component_: torch.optim.AdamW
+  lr: 5e-6
+optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1
+loss:
+  _component_: torchtune.modules.loss.LinearCrossEntropyLoss
+max_steps_per_epoch: null
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+clip_grad_norm: null
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
+
+# Training environment
+device: mps
+
+# Memory management
+enable_activation_checkpointing: True # True reduces memory
+enable_activation_offloading: False # True reduces memory
+
+# Reduced precision
+dtype: fp32
+
+# Logging
+metric_logger:
+  _component_: torchtune.training.metric_logging.StdoutLogger
+  log_dir: ${output_dir}/logs
+log_every_n_steps: 1
+log_peak_memory_stats: True
+log_level: INFO # DEBUG, WARN, etc.
+
+
+# Profiler (disabled)
+profiler:
+  _component_: torchtune.training.setup_torch_profiler
+  enabled: False
+
+  #Output directory of trace artifacts
+  output_dir: ${output_dir}/profiling_outputs
+
+  #`torch.profiler.ProfilerActivity` types to trace
+  cpu: True
+  cuda: True
+
+  #trace options passed to `torch.profiler.profile`
+  profile_memory: False
+  with_stack: False
+  record_shapes: True
+  with_flops: False
+
+  # `torch.profiler.schedule` options:
+  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+  wait_steps: 5
+  warmup_steps: 3
+  active_steps: 2
+  num_cycles: 1
diff --git a/recipes/configs/deepseek_v3/moonlight.yaml b/recipes/configs/deepseek_v3/moonlight.yaml
new file mode 100644
index 0000000000..cbbc25a685
--- /dev/null
+++ b/recipes/configs/deepseek_v3/moonlight.yaml
@@ -0,0 +1,113 @@
+# Config for single device full finetuning in full_finetune_single_device.py
+# using a Moonlight 16B model (DeepSeek V3 architecture, 64 experts).
+#
+# This config assumes that you've downloaded a Moonlight 16B checkpoint
+# (e.g. with tune download) and that the tokenizer and checkpointer paths
+# below point at it.
+#
+# To launch on a single device, run the following command from root:
+#   tune run full_finetune_single_device --config deepseek_v3/moonlight
+#
+# You can add specific overrides through the command line. For example
+# to override the checkpointer directory while launching training
+# you can run:
+#   tune run full_finetune_single_device --config deepseek_v3/moonlight checkpointer.checkpoint_dir=<checkpoint_dir>
+#
+# This config works only for training on single device.
+
+output_dir: /tmp/dsv3 # /tmp may be deleted by your system. Change it to your preference.
+
+# Tokenizer
+tokenizer:
+  _component_: torchtune.models.deepseek_v3._tokenizer.DeepSeekV3Tokenizer
+  path: /Users/salmanmohammadi/projects/torchtune/target/scripts/deepseek/deepseek-v3-micro/tokenizer.json
+  config_path: /Users/salmanmohammadi/projects/torchtune/target/scripts/deepseek/deepseek-v3-micro/tokenizer_config.json
+  max_seq_len: 1024
+
+# Dataset
+dataset:
+  _component_: torchtune.datasets.text_completion_dataset
+  source: openai/gsm8k
+  column: question
+  name: main
+  split: train
+  packed: False # True increases speed
+seed: null
+shuffle: True
+
+# Model Arguments
+model:
+  _component_: torchtune.models.deepseek_v3._model_builders.moonlight_16B_64e
+
+checkpointer:
+  _component_: torchtune.training.FullModelHFCheckpointer
+  checkpoint_dir: /Users/salmanmohammadi/projects/torchtune/target/scripts/deepseek/moonshot
+  checkpoint_files:
+    filename_format: model-{}-of-{}.safetensors
+    max_filename: "27"
+  recipe_checkpoint: null
+  output_dir: ${output_dir}
+  model_type: DEEPSEEK_V3
+resume_from_checkpoint: False
+
+# Fine-tuning arguments
+batch_size: 1
+epochs: 1
+optimizer:
+  _component_: torch.optim.AdamW
+  lr: 5e-6
+optimizer_in_bwd: True # True saves memory.
Requires gradient_accumulation_steps=1 +loss: + _component_: torchtune.modules.loss.LinearCrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null +compile: False # torch.compile the model + loss, True increases speed + decreases memory + +# Training environment +device: mps + +# Memory management +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.StdoutLogger + log_dir: ${output_dir}/logs +log_every_n_steps: 1 +log_peak_memory_stats: True +log_level: INFO # DEBUG, WARN, etc. + + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/torchtune/models/deepseek_v3/__init__.py b/torchtune/models/deepseek_v3/__init__.py new file mode 100644 index 0000000000..2e41cd717f --- /dev/null +++ b/torchtune/models/deepseek_v3/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchtune/models/deepseek_v3/_attention.py b/torchtune/models/deepseek_v3/_attention.py new file mode 100644 index 0000000000..d8b717d32f --- /dev/null +++ b/torchtune/models/deepseek_v3/_attention.py @@ -0,0 +1,101 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +import torch.nn as nn +from typing import Optional +from torchtune.modules.attention_utils import _MaskType +from torchtune.modules import RMSNorm +from torchtune.modules.attention import _sdpa_or_flex_attention + + +class DeepSeekV3Attention(nn.Module): + def __init__(self, + embed_dim: int, + num_heads: int, + qk_rope_head_dim: int, + v_head_dim: int, + qk_nope_head_dim: int, + q_head_dim: int, + q_proj: nn.Module, + kv_proj: nn.Module, + output_proj: nn.Module, + pos_embeddings: nn.Module, + max_seq_len: int = 4096, + is_causal: bool = True, + attn_dropout: float = 0.0,): + super().__init__() + self.num_heads = num_heads + self.embed_dim = embed_dim + self.attn_dropout = attn_dropout + self.q_head_dim = q_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.max_seq_len = max_seq_len + self.is_causal = is_causal + + # Set layers + self.q_proj = q_proj + self.kv_proj = kv_proj + self.output_proj = output_proj + self.pos_embeddings = pos_embeddings + self.softmax_scale = self.q_head_dim ** (-0.5) + if hasattr(self.pos_embeddings, "get_mscale"): + mscale = self.pos_embeddings.get_mscale(self.pos_embeddings.scaling_factor, self.pos_embeddings.mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + self.cache_enabled = False + + self._attention_call = _sdpa_or_flex_attention() + + def forward( + self, + x: torch.Tensor, + y: Optional[torch.Tensor] = None, + *, + mask: Optional[_MaskType] = None, + input_pos: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + b, s_x, _ = x.shape + q = self.q_proj(x) + q = q.view(b, s_x, self.num_heads, self.q_head_dim) + q = q.transpose(1, 2) + + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + kv, k_pe = self.kv_proj(x) + kv = kv.view(b, s_x, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + kv = kv.transpose(1, 2) + + k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + q_pe = self.pos_embeddings(q_pe, input_pos=input_pos) + k_pe = self.pos_embeddings(k_pe, input_pos=input_pos) + + query_states = k_pe.new_empty(b, self.num_heads, s_x, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim:] = q_pe + + key_states = k_pe.new_empty(b, self.num_heads, s_x, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim:] = k_pe + + output = self._attention_call( + query_states, + key_states, + value_states, + mask=mask, + dropout_p=self.attn_dropout if self.training else 0.0, + is_causal=mask is None, + scale=self.softmax_scale, + ) + + # reshape the output to be the same shape as the input + output = output.transpose(1, 2).contiguous().view(b, s_x, -1) + + return self.output_proj(output) diff --git a/torchtune/models/deepseek_v3/_component_builders.py b/torchtune/models/deepseek_v3/_component_builders.py new file mode 100644 index 0000000000..bdc1a45e63 --- /dev/null +++ b/torchtune/models/deepseek_v3/_component_builders.py @@ -0,0 +1,166 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Optional + +import torch +from torch import nn +from torchtune.models.deepseek_v3._experts import DeepseekV3GroupedExperts +from torchtune.models.deepseek_v3._linear import DeepSeekV3LatentLinear +from torchtune.models.deepseek_v3._attention import DeepSeekV3Attention +from torchtune.models.deepseek_v3._moe import DeepSeekV3TokenChoiceTopKRouter, DeepseekV3MoE +from torchtune.modules import ( + FeedForward, + RMSNorm, + TransformerDecoder, + TransformerSelfAttentionLayer, +) +from torchtune.modules.position_embeddings import RotaryPositionalEmbeddings +from torchtune.models.deepseek_v3._position_embeddings import DeepSeekV3YarnRotaryEmbeddings + +def deepseek_v3( + *, + vocab_size: int, + embed_dim: int, + num_layers: int, + num_heads: int, + max_seq_len: int, + rope_base: int = 10_000, + rope_scaling_factor: Optional[float] = None, + original_max_seq_len: Optional[int] = None, + beta_fast: Optional[float] = None, + beta_slow: Optional[float] = None, + mscale: Optional[float] = None, + mscale_all_dim: Optional[float] = None, + q_lora_rank: Optional[int] = None, + qk_rope_head_dim: Optional[int] = None, + qk_nope_head_dim: Optional[int] = None, + kv_lora_rank: Optional[int] = None, + v_head_dim: Optional[int] = None, + moe_every_n_layers: Optional[int] = None, + first_moe_layer: Optional[int] = None, + num_experts: Optional[int] = None, + num_shared_experts: Optional[int] = None, + num_groups: Optional[int] = None, + topk_groups: Optional[int] = None, + norm_topk_prob: Optional[float] = None, + routed_scaling_factor: Optional[float] = None, + experts_per_token: Optional[float] = None, + mlp_hidden_dim: Optional[int] = None, + moe_hidden_dim: Optional[int] = None, + norm_eps: float = 1e-5, + +): + if rope_scaling_factor: + rope = DeepSeekV3YarnRotaryEmbeddings( + dim=qk_rope_head_dim, + max_seq_len=max_seq_len, + base=rope_base, + scaling_factor=rope_scaling_factor, + original_max_seq_len=original_max_seq_len, + beta_fast=beta_fast, + beta_slow=beta_slow, + mscale=mscale, + mscale_all_dim=mscale_all_dim, + ) + else: + rope = RotaryPositionalEmbeddings( + dim=qk_rope_head_dim, + max_seq_len=max_seq_len, + base=rope_base, + ) + layers = [] + for i in range(num_layers): + q_head_dim = qk_rope_head_dim + qk_nope_head_dim + if q_lora_rank is None: + q_proj = nn.Linear(embed_dim, num_heads * q_head_dim, bias=False) + else: + q_proj = DeepSeekV3LatentLinear( + in_dim=embed_dim, + out_dim=num_heads * q_head_dim, + rank=q_lora_rank, + norm=RMSNorm(dim=q_lora_rank), + ) + self_attn = DeepSeekV3Attention( + embed_dim=embed_dim, + num_heads=num_heads, + qk_rope_head_dim=qk_rope_head_dim, + v_head_dim=v_head_dim, + qk_nope_head_dim=qk_nope_head_dim, + q_head_dim=q_head_dim, + q_proj=q_proj, + kv_proj=DeepSeekV3LatentLinear(in_dim=embed_dim, + out_dim=num_heads * (q_head_dim - qk_rope_head_dim + v_head_dim), + rank=kv_lora_rank, + norm=RMSNorm(dim=kv_lora_rank), + rope_head_dim=qk_rope_head_dim), + output_proj=nn.Linear(num_heads * v_head_dim, embed_dim, bias=False), + pos_embeddings=rope, + max_seq_len=max_seq_len, + is_causal=True, + attn_dropout=0.0, + ) + is_moe = (moe_every_n_layers is None or (i + 1) % moe_every_n_layers == 0) and i >= first_moe_layer + if is_moe: + mlp_layer = DeepseekV3MoE( + experts=deepseek_v3_experts(num_experts, embed_dim, moe_hidden_dim), + router=DeepSeekV3TokenChoiceTopKRouter( + dim=embed_dim, + num_experts=num_experts, + experts_per_token=experts_per_token, + num_groups=num_groups, + topk_groups=topk_groups, + norm_topk_prob=norm_topk_prob, + 
routed_scaling_factor=routed_scaling_factor, + ), + shared_expert=deepseek_v3_mlp(embed_dim, moe_hidden_dim * num_shared_experts), + ) + else: + mlp_layer = deepseek_v3_mlp(embed_dim, mlp_hidden_dim) + + layer = TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp_layer, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + layers.append(layer) + + layers = nn.ModuleList(layers) + + tok_embeddings = nn.Embedding(vocab_size, embed_dim) + output_proj = nn.Linear(embed_dim, vocab_size, bias=False) + return TransformerDecoder( + tok_embeddings=tok_embeddings, + layers=layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=embed_dim // num_heads, + norm=RMSNorm(dim=embed_dim, eps=norm_eps), + output=output_proj, + ) + +def deepseek_v3_experts( + num_experts: int, + dim: int, + hidden_dim: int, +) -> nn.ModuleDict: + experts = nn.ModuleDict({ + str(i): deepseek_v3_mlp(dim, hidden_dim) for i in range(num_experts) + }) + return experts + +def deepseek_v3_mlp( + dim: int, + hidden_dim: int +) -> FeedForward: + """ + Builds the FeedForward layer for DeepSeek V3. + """ + gate_proj = nn.Linear(dim, hidden_dim, bias=False) + up_proj = nn.Linear(dim, hidden_dim, bias=False) + down_proj = nn.Linear(hidden_dim, dim, bias=False) + return FeedForward(gate_proj=gate_proj, up_proj=up_proj, down_proj=down_proj) diff --git a/torchtune/models/deepseek_v3/_convert_weights.py b/torchtune/models/deepseek_v3/_convert_weights.py new file mode 100644 index 0000000000..d4385e89ed --- /dev/null +++ b/torchtune/models/deepseek_v3/_convert_weights.py @@ -0,0 +1,88 @@ +import torch +from torchtune.models.convert_weights import get_mapped_key +import regex as re +from typing import Dict + +_FROM_HF = { + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale", + "model.norm.weight": "norm.scale", + + # attenion weights + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight", + "model.layers.{}.self_attn.q_a_proj.weight": "layers.{}.attn.q_proj.a.weight", + "model.layers.{}.self_attn.q_a_layernorm.weight": "layers.{}.attn.q_proj.norm.scale", + "model.layers.{}.self_attn.q_b_proj.weight": "layers.{}.attn.q_proj.b.weight", + "model.layers.{}.self_attn.kv_a_proj_with_mqa.weight": "layers.{}.attn.kv_proj.a.weight", + "model.layers.{}.self_attn.kv_a_layernorm.weight": "layers.{}.attn.kv_proj.norm.scale", + "model.layers.{}.self_attn.kv_b_proj.weight": "layers.{}.attn.kv_proj.b.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight", + + # mlp non-expert weights + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.w1.weight", + "model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.w3.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight", + + # mlp MoE expert weights + "model.layers.{}.mlp.experts.{}.gate_proj.weight": "layers.{}.mlp.experts.{}.w1.weight", + "model.layers.{}.mlp.experts.{}.up_proj.weight": "layers.{}.mlp.experts.{}.w3.weight", + "model.layers.{}.mlp.experts.{}.down_proj.weight": "layers.{}.mlp.experts.{}.w2.weight", + + # mlp MoE shared expert weights + "model.layers.{}.mlp.shared_experts.gate_proj.weight": "layers.{}.mlp.shared_expert.w1.weight", + "model.layers.{}.mlp.shared_experts.up_proj.weight": "layers.{}.mlp.shared_expert.w3.weight", + "model.layers.{}.mlp.shared_experts.down_proj.weight": 
"layers.{}.mlp.shared_expert.w2.weight", + + # mlp MoE token router weights + "model.layers.{}.mlp.gate.weight": "layers.{}.mlp.router.gate", + "model.layers.{}.mlp.gate.e_score_correction_bias": "layers.{}.mlp.router.e_score_correction_bias", + + "lm_head.weight": "output.weight", + "model.layers.{}.self_attn.rotary_emb.inv_freq": None, +} + + +def get_mapped_key(key: str, mapping_dict: Dict[str, str]) -> str: + try: + # Checks if there is a layer # in the key + if any(k.isdigit() for k in key.split(".")): + # Replace all numbers with "{}" to create key for lookup + abstract_key = re.sub(r"(\.\d+)", ".{}", key) + # Find all numbers in the key in order + layer_nums = re.findall(r"\d+", key) + new_key = mapping_dict[abstract_key] + # Format with all numbers + new_key = new_key.format(*layer_nums) + else: + new_key = mapping_dict[key] + except KeyError as e: + raise Exception( + f'Error converting the state dict. Found unexpected key: "{key}". ' + "Please make sure you're loading a checkpoint with the right format. " + ) from e + + return new_key + + +def deepseek_v3_hf_to_tune(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + converted_state_dict = {} + for key, value in state_dict.items(): + # Skip keys that should be ignored (like rotary embeddings) + if "rotary_emb.inv_freq" in key: + continue + + new_key = get_mapped_key(key, _FROM_HF) + converted_state_dict[new_key] = value + return converted_state_dict + + +def deepseek_v3_tune_to_hf(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + converted_state_dict = {} + inverted_mapping_dict = {v: k for k, v in _FROM_HF.items() + } + for key, value in state_dict.items(): + new_key = get_mapped_key(key, inverted_mapping_dict) + converted_state_dict[new_key] = value + + return converted_state_dict diff --git a/torchtune/models/deepseek_v3/_experts.py b/torchtune/models/deepseek_v3/_experts.py new file mode 100644 index 0000000000..b27eda1f0e --- /dev/null +++ b/torchtune/models/deepseek_v3/_experts.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable + +import torch +from torch import nn +from torch.nn import functional as F + + +class DeepseekV3GroupedExperts(nn.Module): + """This class implements the grouped experts layer used in Mixture of Experts. Each expert + is a variant of the Gated Linear Units network. See more details in https://arxiv.org/pdf/2002.05202. + + This class is identical to :class:`~torchtune.modules.moe.experts.GroupedExperts`, except that it uses a + `ModuleDict` to store the gate, down, and up projection matrices for each expert, rather than a + combined `nn.Parameter`. + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension. + num_experts (int): Number of experts in this grouped experts layer. Default is 1. + activation (Callable): Activation function to use. Default is F.silu. 
+ """ + + def __init__( + self, + *, + dim: int, + hidden_dim: int, + num_experts: int = 1, + activation: Callable = F.silu, + ): + super().__init__() + self.dim = dim + self.num_experts = num_experts + self.experts = nn.ModuleDict({ + f"expert_{i}": nn.Linear(dim, hidden_dim) for i in range(num_experts) + }) + self.experts_down = nn.ModuleDict({ + f"expert_{i}": nn.Linear(hidden_dim, dim) for i in range(num_experts) + }) + self.experts_up = nn.ModuleDict({ + f"expert_{i}": nn.Linear(dim, hidden_dim) for i in range(num_experts) + }) + self.act_fn = activation + + # TODO: force no inference mode as a hack to get around + # "Cannot set version_counter for inference tensor" + @torch.inference_mode(mode=False) + def forward( + self, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Tensor with shape ``(bsz * seq_len * experts_per_token, dim)`` + num_tokens_per_expert (torch.Tensor): Tensor with shape ``(num_experts,)`` + enumerating the number of tokens each expert receives + + Returns: + torch.Tensor: tensor with shape (bsz * seq_len * experts_per_token, dim) + """ + # a tuple of tensors indexed by experts + # each with shape (tokens_per_expert(varying), dim) + x = torch.split( + x, + split_size_or_sections=num_tokens_per_expert.tolist(), + dim=0, + ) + out_experts_splits = [] + for expert_idx, x_expert in enumerate(x): + w1, w2, w3 = ( + self.gate_proj[expert_idx], + self.down_proj[expert_idx], + self.up_proj[expert_idx], + ) + h = self.act_fn(torch.matmul(x_expert, w1)) + h = h * torch.matmul(x_expert, w3) + h = torch.matmul(h, w2) + # h shape (tokens_per_expert(varying), dim) + out_experts_splits.append(h) + out = torch.cat(out_experts_splits, dim=0) + + return out diff --git a/torchtune/models/deepseek_v3/_linear.py b/torchtune/models/deepseek_v3/_linear.py new file mode 100644 index 0000000000..df242ee04f --- /dev/null +++ b/torchtune/models/deepseek_v3/_linear.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from typing import Optional, Optional +from torchtune.modules import RMSNorm + + +class DeepSeekV3LatentLinear(nn.Module): + def __init__( + self, + *, + in_dim: int, + out_dim: int, + rank: int, + norm: nn.Module, + rope_head_dim: Optional[int] = None, + ): + super().__init__() + self.rank = rank + self.rope_head_dim = rope_head_dim or 0 + self.a = nn.Linear( + in_features=in_dim, out_features=rank + self.rope_head_dim, bias=False + ) + self.b = nn.Linear(in_features=rank, out_features=out_dim, bias=False) + self.norm = norm + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, s_x, _ = x.shape + out = self.a(x) + + if self.rope_head_dim: + out, rope_out = torch.split(out, [self.rank, self.rope_head_dim], dim=-1) + rope_out = rope_out.view(b, s_x, 1, self.rope_head_dim).transpose(1, 2) + out = self.b(self.norm(out)) + return out, rope_out + + return self.b(self.norm(out)) diff --git a/torchtune/models/deepseek_v3/_model_builders.py b/torchtune/models/deepseek_v3/_model_builders.py new file mode 100644 index 0000000000..76e410c0bd --- /dev/null +++ b/torchtune/models/deepseek_v3/_model_builders.py @@ -0,0 +1,74 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtune.models.deepseek_v3._component_builders import deepseek_v3 + + +def deepseek_v3_6B_64e(): + """ + Builder for a DeepSeek V3 6.1B model with 64 experts. + https://huggingface.co/smohammadi/deepseek-v3-micro + """ + return deepseek_v3( + vocab_size=129280, + num_layers=16, + num_heads=32, + embed_dim=2048, + max_seq_len=163840, + mlp_hidden_dim=5632, + rope_base=10000, + norm_eps=1e-6, + moe_every_n_layers=1, + first_moe_layer=3, + moe_hidden_dim=1024, + num_experts=64, + num_shared_experts=1, + experts_per_token=8, + num_groups=8, + topk_groups=4, + norm_topk_prob=True, + routed_scaling_factor=2.5, + q_lora_rank=256, + kv_lora_rank=128, + qk_rope_head_dim=64, + qk_nope_head_dim=128, + v_head_dim=128, + rope_scaling_factor=40.0, + original_max_seq_len=4096, + beta_fast=32.0, + beta_slow=1.0, + mscale=1.0, + mscale_all_dim=1.0, + ) + + +def moonlight_16B_64e(): + return deepseek_v3( + vocab_size=163840, + num_layers=27, + num_heads=16, + embed_dim=2048, + max_seq_len=8192, + mlp_hidden_dim=11264, + rope_base=50000, + norm_eps=1e-5, + moe_every_n_layers=1, + first_moe_layer=1, + moe_hidden_dim=1408, + num_experts=64, + num_shared_experts=2, + experts_per_token=6, + num_groups=1, + topk_groups=1, + norm_topk_prob=True, + routed_scaling_factor=2.446, + q_lora_rank=None, + kv_lora_rank=512, + qk_rope_head_dim=64, + qk_nope_head_dim=128, + v_head_dim=128, + ) + diff --git a/torchtune/models/deepseek_v3/_moe.py b/torchtune/models/deepseek_v3/_moe.py new file mode 100644 index 0000000000..6ab432c1c9 --- /dev/null +++ b/torchtune/models/deepseek_v3/_moe.py @@ -0,0 +1,169 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn.functional as F +import torch +from torch import nn +from typing import Optional + + +class DeepseekV3MoE(nn.Module): + """This class implements the Mixture of Experts (MoE) layer for DeepSeek V3. + This comprises a set of a router and a set of experts, which are typically smaller than MLP layers in standard + transformer models. The router is used to select a subset of experts for each token, and the selected experts are + then used to compute the output of the MoE layer. See more details in https://arxiv.org/2401.0606. + + Args: + experts (nn.Module): experts module. + router (nn.Module): router module. + shared_expert (Optional[nn.Module]): shared expert module. Default is None. + """ + + def __init__( + self, + *, + experts: nn.ModuleDict, + router: nn.Module, + shared_expert: Optional[nn.Module] = None, + ): + super().__init__() + self.experts = experts + self.router = router + self.shared_expert = shared_expert + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``. + + Returns: + out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. 
+ """ + + b, s, dim = x.shape + # top_scores and selected_indices shape (bs*slen*experts_per_token,) + # num_tokens_per_expert shape (num_experts,) + ( + top_scores, + token_indices, + num_tokens_per_expert, + ) = self.router(x.reshape(b * s, dim)) + + token_indices = token_indices.unsqueeze(1).expand(-1, dim) + # shape (b*s*experts_per_token, dim) + + routed_input = torch.gather( + x.view(-1, dim), + dim=0, + index=token_indices, + ) + + # shape (b*s*top_k, dim) + routed_input = torch.split(routed_input, split_size_or_sections=num_tokens_per_expert.tolist(), dim=0) + routed_output = [] + for expert_idx, x_expert in enumerate(routed_input): + if x_expert.numel() == 0: + routed_output.append(torch.zeros_like(x_expert)) + continue + routed_output.append(self.experts[str(expert_idx)](x_expert)) + + routed_output = torch.cat(routed_output, dim=0) + routed_output = routed_output * top_scores.unsqueeze(-1) + + out = torch.zeros_like(x.reshape(b * s, dim)).to(routed_output.dtype) + if routed_output.numel() > 0: + out.scatter_add_(dim=0, index=token_indices, src=routed_output) + + out = out.view(b, s, dim).to(x.dtype) + + if self.shared_expert is not None: + out = out + self.shared_expert(x) + + return out + + +class DeepSeekV3TokenChoiceTopKRouter(nn.Module): + def __init__(self, + dim: int, + num_experts: int, + experts_per_token: int, + num_groups: int, + topk_groups: int, + norm_topk_prob: bool, + routed_scaling_factor: float + ): + super().__init__() + self.dim = dim + self.num_experts = num_experts + self.experts_per_token = experts_per_token + self.num_groups = num_groups + self.topk_groups = topk_groups + self.norm_topk_prob = norm_topk_prob + self.routed_scaling_factor = routed_scaling_factor + self.e_score_correction_bias = nn.Parameter(torch.rand((self.num_experts))) + self.gate = nn.Parameter(torch.empty((num_experts, dim))) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + n = x.shape[0] + logits = F.linear(x.to(torch.float32), self.gate.to(torch.float32), None) + + # calculate scores for every expert in every group + scores = torch.sigmoid(logits) + scores_for_choice = scores + self.e_score_correction_bias.unsqueeze(0) + + # now calculate scores for every group based on the + # top 2 scores of experts within each group + experts_per_group = self.num_experts // self.num_groups + group_scores = ( + scores_for_choice.view(n, self.num_groups, experts_per_group) + .topk(2, dim=-1)[0].sum(dim=-1) + ) + + # grab the topk_groups number of groups based + # on the scores for each group calculated above + group_idxs = torch.topk(group_scores, k=self.topk_groups, dim=-1, sorted=False).indices + + # mask out all experts within groups which will not be considered + group_mask = torch.zeros_like(group_scores, dtype=torch.bool) + group_mask.scatter_(1, group_idxs, True) # [n, n_group] + + score_mask = ( + group_mask.unsqueeze(-1) + .expand( + n, self.num_groups, experts_per_group + ) + .reshape(n, -1) + ) + # masked_scores = scores + self.e_score_correction_bias.unsqueeze(0) + masked_scores = scores_for_choice.masked_fill( + ~score_mask, float('-inf') + ) + # now select the top experts_per_token number of + # experts based on experts within eligible groups + _, selected_experts_idxs = torch.topk(masked_scores, k=self.experts_per_token, dim=-1, sorted=False) + scores_per_token = scores.gather(1, selected_experts_idxs) + + # normalize scores + if self.num_experts > 1 and self.norm_topk_prob: + denominator = scores_per_token.sum(dim=-1, keepdim=True) + 1e-20 + scores_per_token /= denominator + 
+ # apply scaling factor + scores_per_token = scores_per_token * self.routed_scaling_factor + + num_tokens_per_expert = torch.histc( + selected_experts_idxs.float(), bins=self.num_experts, min=0, max=self.num_experts - 1 + ).to(torch.int32) + + token_idxs_experts_sorted = torch.argsort( + selected_experts_idxs.view(-1), stable=True + ) + + scores_per_expert = scores_per_token.view(-1)[token_idxs_experts_sorted] + token_idxs_experts_sorted = ( + token_idxs_experts_sorted // self.experts_per_token + ) + return scores_per_expert, token_idxs_experts_sorted, num_tokens_per_expert diff --git a/torchtune/models/deepseek_v3/_position_embeddings.py b/torchtune/models/deepseek_v3/_position_embeddings.py new file mode 100644 index 0000000000..ba0925c73a --- /dev/null +++ b/torchtune/models/deepseek_v3/_position_embeddings.py @@ -0,0 +1,197 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Optional + +import torch +from torch import nn + + +class DeepSeekV3YarnRotaryEmbeddings(nn.Module): + """ + This class implements YaRN (Yet another RoPE extensioN) Rotary Positional Embeddings + for DeepSeek v3, proposed in https://arxiv.org/abs/2309.00071. + + YaRN extends RoPE to longer sequence lengths by selectively applying frequency scaling + to different parts of the frequency spectrum based on wavelength characteristics. + It also includes magnitude scaling to preserve attention patterns. + + Args: + dim (int): Embedding dimension. This is usually set to the dim of each + head in the attention module computed as ``embed_dim // num_heads`` + max_seq_len (int): Maximum expected sequence length for the + model, if exceeded the cached freqs will be recomputed + base (int): The base for the geometric progression used to compute + the rotation angles + scaling_factor (float): Factor by which to scale the original context length + original_max_seq_len (int): Original maximum sequence length before scaling + beta_fast (float): Lower bound for frequency scaling range. Default: 32 + beta_slow (float): Upper bound for frequency scaling range. Default: 1 + mscale (float): Magnitude scaling factor. Default: 1 + mscale_all_dim (float): Magnitude scaling for all dimensions. 
Default: 0 + """ + + def __init__( + self, + dim: int, + max_seq_len: int = 4096, + base: int = 10_000, + scaling_factor: float = 1.0, + original_max_seq_len: int = 4096, + beta_fast: float = 32.0, + beta_slow: float = 1.0, + mscale: float = 1.0, + mscale_all_dim: float = 1.0, + ) -> None: + super().__init__() + self.dim = dim + self.base = base + self.max_seq_len = max_seq_len + self.scaling_factor = scaling_factor + self.original_max_seq_len = original_max_seq_len + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + self.rope_init() + + def _find_correction_dim( + self, num_rotations: float, dim: int, base: int, max_position_embeddings: int + ) -> float: + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + def _find_correction_range( + self, low_rot: float, high_rot: float, dim: int, base: int, max_position_embeddings: int + ) -> tuple[int, int]: + low = math.floor( + self._find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + self._find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) + + def get_mscale(self, scale: float = 1.0, mscale: float = 1.0) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + def _get_linear_ramp_mask(self, min_val: int, max_val: int, dim: int) -> torch.Tensor: + if min_val == max_val: + max_val += 0.001 + + linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + def rope_init(self): + # Compute base extrapolated freqs + freq_base = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim) + ) + + # Compute scaled interpolated freqs + freq_interp = 1.0 / ( + self.scaling_factor + * self.base ** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim) + ) + + # Find correction range for frequency interpolation + low, high = self._find_correction_range( + self.beta_fast, + self.beta_slow, + self.dim, + self.base, + self.original_max_seq_len, + ) + + # Create interpolation mask + inv_freq_mask = 1.0 - self._get_linear_ramp_mask(low, high, self.dim // 2) + + # Interpolate between scaled and unscaled frequencies + theta = freq_interp * (1 - inv_freq_mask) + freq_base * inv_freq_mask + + self.register_buffer("theta", theta, persistent=False) + self.build_rope_cache(self.max_seq_len) + + def build_rope_cache(self, max_seq_len: int = 4096) -> None: + # Create position indexes `[0, 1, ..., max_seq_len - 1]` + seq_idx = torch.arange( + max_seq_len, dtype=self.theta.dtype, device=self.theta.device + ) + + # Outer product of theta and position index; output tensor has + # a shape of [max_seq_len, dim // 2] + idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float() + + # Calculate magnitude scaling + mscale = float( + self.get_mscale(self.scaling_factor, self.mscale) + / self.get_mscale(self.scaling_factor, self.mscale_all_dim) + ) + + # cache includes both the cos and sin components and so the output shape is + # [max_seq_len, dim // 2, 2] + cache = torch.stack([idx_theta.cos() * mscale, idx_theta.sin() * mscale], dim=-1) + self.register_buffer("cache", cache, persistent=False) + + def forward( + self, x: torch.Tensor, *, input_pos: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Args: + x (torch.Tensor): input tensor with shape + ``[b, s, n_h, h_d]`` 
+ input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape [b, s]. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + + Returns: + torch.Tensor: output tensor with shape ``[b, s, n_h, h_d]`` + + Notation used for tensor shapes: + - b: batch size + - s: sequence length + - n_h: num heads + - h_d: head dim + """ + # input tensor has shape [b, s, n_h, h_d] + seq_len = x.size(1) + + # extract the values based on whether input_pos is set or not + rope_cache = ( + self.cache[:seq_len] if input_pos is None else self.cache[input_pos] + ) + + # reshape input; the last dimension is used for computing the output. + # Cast to float to match the reference implementation + # tensor has shape [b, s, n_h, h_d // 2, 2] + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + + # reshape the cache for broadcasting + # tensor has shape [b, s, 1, h_d // 2, 2] if packed samples, + # otherwise has shape [1, s, 1, h_d // 2, 2] + rope_cache = rope_cache.view(-1, xshaped.size(1), 1, xshaped.size(3), 2) + + # tensor has shape [b, s, n_h, h_d // 2, 2] + x_out = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] + - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + + # tensor has shape [b, s, n_h, h_d] + x_out = x_out.flatten(3) + return x_out.type_as(x) diff --git a/torchtune/models/deepseek_v3/_tokenizer.py b/torchtune/models/deepseek_v3/_tokenizer.py new file mode 100644 index 0000000000..768520978c --- /dev/null +++ b/torchtune/models/deepseek_v3/_tokenizer.py @@ -0,0 +1,55 @@ +from typing import Optional +from torchtune.modules.transforms.tokenizers import HuggingFaceBaseTokenizer, ModelTokenizer +from torchtune.modules.transforms import Transform +from functools import cached_property + +class DeepSeekV3Tokenizer(ModelTokenizer, Transform): + + def __init__(self, + path: str, + config_path: str, + max_seq_len: Optional[int] = None, + ) -> None: + self.hf_tokenizer = HuggingFaceBaseTokenizer(path, tokenizer_config_json_path=config_path) + self.max_seq_len = max_seq_len + + @property + def vocab_size(self) -> int: + return self.hf_tokenizer.get_vocab_size() + + def encode(self, *args, **kwargs) -> list[int]: + return self.hf_tokenizer.encode(*args, **kwargs) + + def decode(self, *args, **kwargs) -> str: + return self.hf_tokenizer.decode(*args, **kwargs) + + @property + def bos_id(self) -> int: + return self.hf_tokenizer.bos_id + + @property + def eos_id(self) -> int: + return self.hf_tokenizer.eos_id + + @cached_property + def pad_id(self) -> int: + return self.hf_tokenizer.tokenizer.token_to_id(self.hf_tokenizer.config.get("pad_token")) + + +if __name__ == "__main__": + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained("smohammadi/deepseek-v3-micro") + text = "Hello, world!" 
+ tokens = tokenizer.encode(text, add_special_tokens=True) + print(tokens) + print(tokenizer.decode(tokens)) + + tt_tokenizer = DeepSeekV3Tokenizer( + path="/Users/salmanmohammadi/projects/torchtune/target/scripts/deepseek/deepseek-v3-micro/tokenizer.json", + config_path="/Users/salmanmohammadi/projects/torchtune/target/scripts/deepseek/deepseek-v3-micro/tokenizer_config.json", + max_seq_len=1024 + ) + tt_tokens = tt_tokenizer.encode(text, add_bos=True, add_eos=True) + print(tt_tokens) + print(tt_tokenizer.decode(tt_tokens)) + import ipdb; ipdb.set_trace() \ No newline at end of file diff --git a/torchtune/models/qwen2/__init__.py b/torchtune/models/qwen2/__init__.py index 8e04fba85d..169a67dab4 100644 --- a/torchtune/models/qwen2/__init__.py +++ b/torchtune/models/qwen2/__init__.py @@ -15,7 +15,7 @@ qwen2_7b, qwen2_tokenizer, ) -from ._positional_embeddings import Qwen2RotaryPositionalEmbeddings +from ._position_embeddings import Qwen2RotaryPositionalEmbeddings from ._tokenizer import Qwen2Tokenizer __all__ = [ diff --git a/torchtune/models/qwen2/_component_builders.py b/torchtune/models/qwen2/_component_builders.py index 45e0cb1d60..5ce667f0ed 100644 --- a/torchtune/models/qwen2/_component_builders.py +++ b/torchtune/models/qwen2/_component_builders.py @@ -10,7 +10,7 @@ from torch import nn from torchtune.modules.transformer import TransformerDecoder -from torchtune.models.qwen2._positional_embeddings import Qwen2RotaryPositionalEmbeddings +from torchtune.models.qwen2._position_embeddings import Qwen2RotaryPositionalEmbeddings from torchtune.modules import ( MultiHeadAttention, diff --git a/torchtune/models/qwen2/_positional_embeddings.py b/torchtune/models/qwen2/_position_embeddings.py similarity index 98% rename from torchtune/models/qwen2/_positional_embeddings.py rename to torchtune/models/qwen2/_position_embeddings.py index 61e8682783..48b74ba8e4 100644 --- a/torchtune/models/qwen2/_positional_embeddings.py +++ b/torchtune/models/qwen2/_position_embeddings.py @@ -49,7 +49,7 @@ def rope_init(self): self.build_rope_cache(self.max_seq_len) def build_rope_cache(self, max_seq_len: int = 4096) -> None: - # Create position indexes `[0, 1, ..., max_seq_len - 1]` + # Create position indexes [0, 1, ..., max_seq_len - 1] seq_idx = torch.arange( max_seq_len, dtype=self.theta.dtype, device=self.theta.device ) diff --git a/torchtune/modules/attention_utils.py b/torchtune/modules/attention_utils.py index f2f4985029..04919248f1 100644 --- a/torchtune/modules/attention_utils.py +++ b/torchtune/modules/attention_utils.py @@ -53,8 +53,9 @@ def compile_friendly_flex_attention( k: torch.Tensor, v: torch.Tensor, block_mask: BlockMask, + scale: float, ) -> torch.Tensor: - return flex_attention_compiled(q, k, v, block_mask=block_mask) + return flex_attention_compiled(q, k, v, block_mask=block_mask, scale=scale) _MaskType = Union[torch.Tensor, BlockMask] else: @@ -201,6 +202,7 @@ def _attention_call( mask: Optional[_MaskType], dropout_p: float, is_causal: bool, + scale: Optional[float] = None, ) -> torch.Tensor: # Flex attention uses the BlockMask # (https://github.com/pytorch/pytorch/blob/main/torch/nn/attention/flex_attention.py#L168) @@ -223,6 +225,7 @@ def _attention_call( k, v, block_mask=mask, + scale=scale, ) # If mask is a standard boolean tensor or None, then use SDPA else: @@ -238,6 +241,7 @@ def _attention_call( attn_mask=mask, dropout_p=dropout_p, is_causal=is_causal, + scale=scale, ) else: @@ -249,6 +253,7 @@ def _attention_call( mask: Optional[_MaskType], dropout_p: float, is_causal: bool, + 
scale: Optional[float] = None, ) -> torch.Tensor: # shape: [b, 1, s, s] if mask is not None: @@ -262,6 +267,7 @@ def _attention_call( attn_mask=mask, dropout_p=dropout_p, is_causal=is_causal, + scale=scale, ) return _attention_call diff --git a/torchtune/modules/moe/moe.py b/torchtune/modules/moe/moe.py index b6fd008356..15b83dd13a 100644 --- a/torchtune/modules/moe/moe.py +++ b/torchtune/modules/moe/moe.py @@ -109,7 +109,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``. - Returns: out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. """ diff --git a/torchtune/training/checkpointing/_checkpointer.py b/torchtune/training/checkpointing/_checkpointer.py index 320224cabb..b52692dd4b 100644 --- a/torchtune/training/checkpointing/_checkpointer.py +++ b/torchtune/training/checkpointing/_checkpointer.py @@ -673,6 +673,12 @@ def load_checkpoint(self) -> dict[str, Any]: converted_state_dict[training.MODEL_KEY] = llama4_hf_to_tune( merged_state_dict, ) + elif self._model_type == ModelType.DEEPSEEK_V3: + from torchtune.models.deepseek_v3._convert_weights import deepseek_v3_hf_to_tune + + converted_state_dict[training.MODEL_KEY] = deepseek_v3_hf_to_tune( + merged_state_dict, + ) else: converted_state_dict[training.MODEL_KEY] = convert_weights.hf_to_tune( merged_state_dict, @@ -802,6 +808,12 @@ def save_checkpoint( state_dict[training.MODEL_KEY] = llama4_tune_to_hf( state_dict[training.MODEL_KEY], ) + elif self._model_type == ModelType.DEEPSEEK_V3: + from torchtune.models.deepseek_v3._convert_weights import deepseek_v3_tune_to_hf + + state_dict[training.MODEL_KEY] = deepseek_v3_tune_to_hf( + state_dict[training.MODEL_KEY], + ) else: state_dict[training.MODEL_KEY] = convert_weights.tune_to_hf( state_dict[training.MODEL_KEY], diff --git a/torchtune/training/checkpointing/_utils.py b/torchtune/training/checkpointing/_utils.py index 1dde03a121..2c6c32e859 100644 --- a/torchtune/training/checkpointing/_utils.py +++ b/torchtune/training/checkpointing/_utils.py @@ -110,6 +110,7 @@ class ModelType(Enum): >>> state_dict = my_custom_state_dict_mapping(state_dict) """ + DEEPSEEK_V3: str = "deepseek_v3" GEMMA: str = "gemma" GEMMA2: str = "gemma2" LLAMA2: str = "llama2"
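
A minimal sketch (not part of the diff above) of how the key mapping in the new
_convert_weights.py behaves on a few representative HF checkpoint keys, assuming torchtune is
installed with this change applied. The tensor shapes are placeholders, not the real model shapes.

import torch

from torchtune.models.deepseek_v3._convert_weights import (
    deepseek_v3_hf_to_tune,
    deepseek_v3_tune_to_hf,
)

# Tiny synthetic state dict covering embedding, latent KV projection,
# a routed expert, the router correction bias, and a key that is dropped.
hf_sd = {
    "model.embed_tokens.weight": torch.zeros(8, 4),
    "model.layers.0.self_attn.kv_a_proj_with_mqa.weight": torch.zeros(4, 4),
    "model.layers.3.mlp.experts.7.gate_proj.weight": torch.zeros(4, 4),
    "model.layers.3.mlp.gate.e_score_correction_bias": torch.zeros(4),
    "model.layers.0.self_attn.rotary_emb.inv_freq": torch.zeros(2),  # skipped by hf_to_tune
}

tune_sd = deepseek_v3_hf_to_tune(hf_sd)
print(sorted(tune_sd.keys()))
# Expected, per the _FROM_HF table (inv_freq is dropped):
#   layers.0.attn.kv_proj.a.weight
#   layers.3.mlp.experts.7.w1.weight
#   layers.3.mlp.router.e_score_correction_bias
#   tok_embeddings.weight

# Converting back restores the original HF names for everything that was kept.
roundtrip = deepseek_v3_tune_to_hf(tune_sd)
assert set(roundtrip) == set(hf_sd) - {"model.layers.0.self_attn.rotary_emb.inv_freq"}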