[WIP] DSV3 #2764

Draft: wants to merge 26 commits into base: main
115 changes: 115 additions & 0 deletions recipes/configs/deepseek_v3/6B_64e_full_single_device.yaml
@@ -0,0 +1,115 @@
# Config for single device full finetuning in full_finetune_single_device.py
# using a DeepSeek V3 6B model with 64 experts.
#
# This config assumes that the tokenizer and checkpoint files referenced
# below are available locally; update the paths to match your setup.
#
# To launch on a single device, run the following command from root:
# tune run full_finetune_single_device --config deepseek_v3/6B_64e_full_single_device
#
# You can add specific overrides through the command line. For example,
# to override the checkpointer directory while launching training
# you can run:
# tune run full_finetune_single_device --config deepseek_v3/6B_64e_full_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on a single device.

output_dir: /tmp/dsv3 # /tmp may be deleted by your system. Change it to your preference.

# Tokenizer
tokenizer:
  _component_: torchtune.models.deepseek_v3._tokenizer.DeepSeekV3Tokenizer
  path: /Users/salmanmohammadi/projects/torchtune/target/scripts/deepseek/deepseek-v3-micro/tokenizer.json
  config_path: /Users/salmanmohammadi/projects/torchtune/target/scripts/deepseek/deepseek-v3-micro/tokenizer_config.json
  max_seq_len: 1024

# Dataset
dataset:
  _component_: torchtune.datasets.text_completion_dataset
  source: openai/gsm8k
  column: question
  name: main
  split: train
  packed: False # True increases speed
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.deepseek_v3._model_builders.deepseek_v3_6B_64e

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /Users/salmanmohammadi/projects/torchtune/target/scripts/deepseek/deepseek-v3-micro
  checkpoint_files: [
    model-00001-of-00003.safetensors,
    model-00002-of-00003.safetensors,
    model-00003-of-00003.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: DEEPSEEK_V3
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 1
epochs: 1
optimizer:
  _component_: torch.optim.AdamW
  lr: 5e-6
optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1
loss:
  _component_: torchtune.modules.loss.LinearCrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1 # Use to increase effective batch size
clip_grad_norm: null
compile: False # torch.compile the model + loss, True increases speed + decreases memory

# Training environment
device: mps

# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: fp32

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.StdoutLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
log_level: INFO # DEBUG, WARN, etc.


# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
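
For readers unfamiliar with torchtune's config system, each `_component_` field above is a dotted import path that gets imported and called with its sibling keys as keyword arguments. A minimal sketch of that resolution, assuming this config's path on disk (this is the same mechanism `tune run` uses, which is why CLI overrides like `optimizer.lr=1e-5` work by rewriting config nodes before instantiation):

from omegaconf import OmegaConf
from torchtune import config

cfg = OmegaConf.load("recipes/configs/deepseek_v3/6B_64e_full_single_device.yaml")

# Imports and calls torchtune.models.deepseek_v3._model_builders.deepseek_v3_6B_64e
model = config.instantiate(cfg.model)

# Positional args are forwarded to the component, so this is equivalent to
# torch.optim.AdamW(model.parameters(), lr=5e-6)
optimizer = config.instantiate(cfg.optimizer, model.parameters())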
113 changes: 113 additions & 0 deletions recipes/configs/deepseek_v3/moonlight.yaml
@@ -0,0 +1,113 @@
# Config for single device full finetuning in full_finetune_single_device.py
# using a Moonlight 16B model with 64 experts.
#
# This config assumes that the tokenizer and checkpoint files referenced
# below are available locally; update the paths to match your setup.
#
# To launch on a single device, run the following command from root:
# tune run full_finetune_single_device --config deepseek_v3/moonlight
#
# You can add specific overrides through the command line. For example,
# to override the checkpointer directory while launching training
# you can run:
# tune run full_finetune_single_device --config deepseek_v3/moonlight checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on a single device.

output_dir: /tmp/dsv3 # /tmp may be deleted by your system. Change it to your preference.

# Tokenizer
tokenizer:
  _component_: torchtune.models.deepseek_v3._tokenizer.DeepSeekV3Tokenizer
  path: /Users/salmanmohammadi/projects/torchtune/target/scripts/deepseek/deepseek-v3-micro/tokenizer.json
  config_path: /Users/salmanmohammadi/projects/torchtune/target/scripts/deepseek/deepseek-v3-micro/tokenizer_config.json
  max_seq_len: 1024

# Dataset
dataset:
  _component_: torchtune.datasets.text_completion_dataset
  source: openai/gsm8k
  column: question
  name: main
  split: train
  packed: False # True increases speed
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.deepseek_v3._model_builders.moonlight_16B_64e

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /Users/salmanmohammadi/projects/torchtune/target/scripts/deepseek/moonshot
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "27"
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: DEEPSEEK_V3
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 1
epochs: 1
optimizer:
  _component_: torch.optim.AdamW
  lr: 5e-6
optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1
loss:
  _component_: torchtune.modules.loss.LinearCrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1 # Use to increase effective batch size
clip_grad_norm: null
compile: False # torch.compile the model + loss, True increases speed + decreases memory

# Training environment
device: mps

# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.StdoutLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
log_level: INFO # DEBUG, WARN, etc.


# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
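
Unlike the first config, this one specifies `checkpoint_files` as a filename pattern rather than an explicit list. A hedged plain-Python sketch of how such a pattern can expand into a shard list; the real expansion (including the exact zero-padding convention) is handled by torchtune's FormattedCheckpointFiles, so the padding below is an assumption for illustration:

# The two {} placeholders are the shard index and the total shard count.
filename_format = "model-{}-of-{}.safetensors"
max_filename = "27"

shards = [
    filename_format.format(str(i).zfill(len(max_filename)), max_filename)
    for i in range(1, int(max_filename) + 1)
]
print(shards[0])   # model-01-of-27.safetensors
print(shards[-1])  # model-27-of-27.safetensors

This keeps the config short when a checkpoint is split across many shards (27 here).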
5 changes: 5 additions & 0 deletions torchtune/models/deepseek_v3/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
101 changes: 101 additions & 0 deletions torchtune/models/deepseek_v3/_attention.py
@@ -0,0 +1,101 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from typing import Optional

from torchtune.modules.attention_utils import _MaskType
from torchtune.modules import RMSNorm
from torchtune.modules.attention import _sdpa_or_flex_attention


class DeepSeekV3Attention(nn.Module):
    """Multi-head Latent Attention (MLA) as used in DeepSeek V3.

    Queries and keys are split per head into a no-RoPE part
    (``qk_nope_head_dim``) and a decoupled RoPE part (``qk_rope_head_dim``);
    ``q_head_dim`` is their sum.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        qk_nope_head_dim: int,
        q_head_dim: int,
        q_proj: nn.Module,
        kv_proj: nn.Module,
        output_proj: nn.Module,
        pos_embeddings: nn.Module,
        max_seq_len: int = 4096,
        is_causal: bool = True,
        attn_dropout: float = 0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.attn_dropout = attn_dropout
        self.q_head_dim = q_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.max_seq_len = max_seq_len
        self.is_causal = is_causal

        # Set layers
        self.q_proj = q_proj
        self.kv_proj = kv_proj
        self.output_proj = output_proj
        self.pos_embeddings = pos_embeddings

        # Attention logits are scaled by the full query head dim
        # (no-RoPE + RoPE parts combined).
        self.softmax_scale = self.q_head_dim ** (-0.5)
        # YaRN-style length extrapolation: if the positional embedding exposes
        # an mscale, fold it into the softmax scale. It applies to both
        # queries and keys, hence the square.
        if hasattr(self.pos_embeddings, "get_mscale"):
            mscale = self.pos_embeddings.get_mscale(
                self.pos_embeddings.scaling_factor, self.pos_embeddings.mscale_all_dim
            )
            self.softmax_scale = self.softmax_scale * mscale * mscale
        self.cache_enabled = False

        self._attention_call = _sdpa_or_flex_attention()

    def forward(
        self,
        x: torch.Tensor,
        y: Optional[torch.Tensor] = None,  # unused; kept for interface parity
        *,
        mask: Optional[_MaskType] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        b, s_x, _ = x.shape

        # Project queries, then split each head into no-RoPE and RoPE parts.
        q = self.q_proj(x)
        q = q.view(b, s_x, self.num_heads, self.q_head_dim)
        q = q.transpose(1, 2)  # [b, n_h, s_x, q_head_dim]

        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        # The kv projection returns per-head no-RoPE key/value states and a
        # decoupled RoPE key (shared across heads in MLA).
        kv, k_pe = self.kv_proj(x)
        kv = kv.view(b, s_x, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        kv = kv.transpose(1, 2)

        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )

        # Rotary embeddings are applied only to the RoPE parts.
        q_pe = self.pos_embeddings(q_pe, input_pos=input_pos)
        k_pe = self.pos_embeddings(k_pe, input_pos=input_pos)

        # Reassemble full queries and keys as [no-RoPE dims | RoPE dims].
        query_states = k_pe.new_empty(b, self.num_heads, s_x, self.q_head_dim)
        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

        key_states = k_pe.new_empty(b, self.num_heads, s_x, self.q_head_dim)
        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe

        output = self._attention_call(
            query_states,
            key_states,
            value_states,
            mask=mask,
            dropout_p=self.attn_dropout if self.training else 0.0,
            is_causal=mask is None,
            scale=self.softmax_scale,
        )

        # Reshape the output to be the same shape as the input
        output = output.transpose(1, 2).contiguous().view(b, s_x, -1)

        return self.output_proj(output)
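
To make the shape contract concrete, here is a minimal smoke test with stub projections. The stub shapes are assumptions for illustration (the real MLA projections are wired up by `_model_builders.py`), and it assumes `_sdpa_or_flex_attention` on this branch accepts the `scale` keyword used in the forward pass above:

import torch
import torch.nn as nn

class IdentityRoPE(nn.Module):
    # Stand-in for the real positional embedding; passes tensors through.
    def forward(self, x, input_pos=None):
        return x

embed_dim, num_heads = 64, 4
qk_nope, qk_rope, v_dim = 8, 4, 8
q_head_dim = qk_nope + qk_rope  # 12

class StubKVProj(nn.Module):
    # Mimics the expected (kv, k_pe) return: kv packs the per-head no-RoPE
    # keys and values; k_pe is a single RoPE key broadcast across heads.
    def __init__(self):
        super().__init__()
        self.kv = nn.Linear(embed_dim, num_heads * (qk_nope + v_dim), bias=False)
        self.pe = nn.Linear(embed_dim, qk_rope, bias=False)

    def forward(self, x):
        b, s, _ = x.shape
        # [b, 1, s, rope] broadcasts over the head dim on assignment
        return self.kv(x), self.pe(x).view(b, 1, s, qk_rope)

attn = DeepSeekV3Attention(
    embed_dim=embed_dim,
    num_heads=num_heads,
    qk_rope_head_dim=qk_rope,
    v_head_dim=v_dim,
    qk_nope_head_dim=qk_nope,
    q_head_dim=q_head_dim,
    q_proj=nn.Linear(embed_dim, num_heads * q_head_dim, bias=False),
    kv_proj=StubKVProj(),
    output_proj=nn.Linear(num_heads * v_dim, embed_dim, bias=False),
    pos_embeddings=IdentityRoPE(),
)

x = torch.randn(2, 16, embed_dim)
print(attn(x).shape)  # torch.Size([2, 16, 64])

Note that the value head dim (8) differs from the query/key head dim (12); SDPA supports mismatched value dims, which is what lets MLA use a value dimension smaller than the concatenated no-RoPE + RoPE key dimension.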