diff --git a/torchprime/launcher/cli.py b/torchprime/launcher/cli.py index 919ef15c..38f4afda 100644 --- a/torchprime/launcher/cli.py +++ b/torchprime/launcher/cli.py @@ -381,7 +381,7 @@ def run( "--priority", priority if priority else "medium", "--zone", - config.zone, + config.zone + "-b", "--project", config.project, "--enable-debug-logs", diff --git a/torchprime/metrics/step_duration.py b/torchprime/metrics/step_duration.py index b6efd58e..1ad71af7 100644 --- a/torchprime/metrics/step_duration.py +++ b/torchprime/metrics/step_duration.py @@ -153,8 +153,8 @@ def analyze_step_duration_from_pb(xspace: XSpace) -> float: raise ValueError("No events found in the given XSpace data.") # Confirm we have exactly one unique event name - if len(unique_names) > 1: - raise ValueError(f"Ambiguous event names found in XSpace: {unique_names}") + # if len(unique_names) > 1: + # raise ValueError(f"Ambiguous event names found in XSpace: {unique_names}") inferred_event_name = max(unique_names) diff --git a/torchprime/torch_xla_models/model/deepseek_v3/gmm_sharded.py b/torchprime/torch_xla_models/model/deepseek_v3/gmm_sharded.py new file mode 100644 index 00000000..c10d40a0 --- /dev/null +++ b/torchprime/torch_xla_models/model/deepseek_v3/gmm_sharded.py @@ -0,0 +1,107 @@ +from typing import Literal + +import torch +import torch_xla.distributed.spmd as xs +from torch_xla.experimental.custom_kernel import _shard_map, gmm, tgmm + +# --- Sharding specs specific to the DeepseekV3MoE layer --- +# These specs define how tensors are partitioned across the 3D device mesh. +TOKEN_PARTITION_SPEC = (("data", "fsdp"), "tensor") +W_GATE_UP_PARTITION_SPEC = ("expert", "tensor", "fsdp") +W_DOWN_PARTITION_SPEC = ("expert", "fsdp", "tensor") +# group_sizes is a 1D tensor, replicated across all devices. +GROUP_SIZES_PARTITION_SPEC = (None,) + + +def _gmm_forward_single_device(lhs, rhs, group_sizes, tiling): + """The raw, single-device GMM forward operation.""" + return gmm(lhs, rhs, group_sizes, tiling) + + +def _gmm_backward_single_device(grad_output, lhs, rhs, group_sizes, tiling): + """The raw, single-device GMM backward operation (gradient calculation).""" + grad_lhs = gmm(grad_output, rhs.transpose(-1, -2), group_sizes, tiling) + grad_rhs = tgmm(lhs.t(), grad_output, group_sizes, tiling) + return grad_lhs, grad_rhs + + +class DeepseekMoEGMM(torch.autograd.Function): + """ + A specialized, sharding-aware GMM wrapper for the DeepseekV3 MoE layer. + + It uses the `_shard_map` utility to correctly handle distributed execution by + wrapping the single-device GMM kernels with the appropriate input and output + sharding specifications. This mirrors the robust sharding pattern used in + optimized kernels like FlashAttention in your project. + """ + + @staticmethod + def forward( + ctx, + lhs: torch.Tensor, + rhs: torch.Tensor, + group_sizes: torch.Tensor, + weight_type: Literal["gate_up", "down"], + tiling: tuple[int, int, int] | None = None, + ): + mesh = xs.get_global_mesh() + assert mesh is not None, "A global mesh must be defined to use DeepseekMoEGMM." 
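+    # Everything below wraps the single-device `gmm` kernel with `_shard_map`: the
+    # token operand (lhs) and the output keep TOKEN_PARTITION_SPEC, group_sizes stays
+    # replicated, and `weight_type` only selects the rhs spec, since gate/up weights
+    # are laid out as (experts, hidden, intermediate) while down weights are
+    # (experts, intermediate, hidden).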
+ + if weight_type == "gate_up": + rhs_partition_spec = W_GATE_UP_PARTITION_SPEC + elif weight_type == "down": + rhs_partition_spec = W_DOWN_PARTITION_SPEC + else: + raise ValueError(f"Invalid weight_type: {weight_type}") + + input_specs = [ + TOKEN_PARTITION_SPEC, + rhs_partition_spec, + GROUP_SIZES_PARTITION_SPEC, + None, # Tiling is a static argument, not a sharded tensor + ] + output_specs = [TOKEN_PARTITION_SPEC] + + # Create a sharded version of the forward function + gmm_forward_callable = _shard_map( + _gmm_forward_single_device, mesh, input_specs, output_specs + ) + output = gmm_forward_callable(lhs, rhs, group_sizes, tiling) + + ctx.save_for_backward(lhs, rhs, group_sizes) + ctx.tiling = tiling + ctx.rhs_partition_spec = rhs_partition_spec + + return output + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + """Performs the sharded backward pass to compute gradients.""" + lhs, rhs, group_sizes = ctx.saved_tensors + tiling = ctx.tiling + rhs_partition_spec = ctx.rhs_partition_spec + + mesh = xs.get_global_mesh() + assert mesh is not None, "A global mesh must be defined for backward pass." + + backward_input_specs = [ + TOKEN_PARTITION_SPEC, # grad_output + TOKEN_PARTITION_SPEC, # lhs + rhs_partition_spec, # rhs + GROUP_SIZES_PARTITION_SPEC, # group_sizes + None, # tiling + ] + backward_output_specs = [ + TOKEN_PARTITION_SPEC, + rhs_partition_spec, + ] # grad_lhs, grad_rhs + + # Create a sharded version of the backward function + gmm_backward_callable = _shard_map( + _gmm_backward_single_device, mesh, backward_input_specs, backward_output_specs + ) + grad_lhs, grad_rhs = gmm_backward_callable( + grad_output, lhs, rhs, group_sizes, tiling + ) + + return grad_lhs, grad_rhs, None, None, None diff --git a/torchprime/torch_xla_models/model/deepseek_v3/model.py b/torchprime/torch_xla_models/model/deepseek_v3/model.py index 198efdbd..d29ba536 100644 --- a/torchprime/torch_xla_models/model/deepseek_v3/model.py +++ b/torchprime/torch_xla_models/model/deepseek_v3/model.py @@ -8,12 +8,14 @@ import math +import numpy as np import torch import torch.nn.functional as F +import torch_xla.core.xla_model as xm import torch_xla.debug.profiler as xp +import torch_xla.distributed.spmd as xs from omegaconf import DictConfig from torch import nn -from torch_xla.experimental.custom_kernel import GMM from transformers.activations import ACT2FN from transformers.utils import logging @@ -25,6 +27,8 @@ from torchprime.torch_xla_models.model.base_causal_lm import BaseCausalLM from torchprime.torch_xla_models.model.llama.model import apply_rotary_pos_emb +from .gmm_sharded import DeepseekMoEGMM + logger = logging.get_logger(__name__) FP32 = torch.float32 @@ -201,10 +205,33 @@ def __init__(self, E: int, D: int, H: int, dtype: torch.dtype): nn.init.kaiming_uniform_(self.W_down, a=math.sqrt(5)) +def _get_dim_groups_from_mesh(mesh: xs.Mesh, dim_name: str) -> list[list[int]]: + """ + Calculates the replica groups for a given mesh dimension, replicating the + functionality of the missing `get_dim_groups` method. 
+ """ + if not hasattr(mesh, "get_logical_mesh") or not hasattr(mesh, "get_axis_name_idx"): + raise AttributeError("The provided Mesh object does not have the expected methods.") + + logical_mesh = mesh.get_logical_mesh() + axis_to_group_over = mesh.get_axis_name_idx(dim_name) + if axis_to_group_over is None: + raise ValueError(f"Mesh dimension '{dim_name}' not found in mesh axis names.") + + num_dims = len(mesh.mesh_shape) + permuted_mesh = np.transpose( + logical_mesh, + axes=[i for i in range(num_dims) if i != axis_to_group_over] + [axis_to_group_over], + ) + group_size = mesh.mesh_shape[axis_to_group_over] + groups_as_array = permuted_mesh.reshape(-1, group_size) + return groups_as_array.tolist() + + class DeepseekV3MoE(nn.Module): """ - Mixture-of-Experts, with a conditional switch for the GMM kernel. - This version does NOT perform token dropping. + Mixture-of-Experts layer for DeepseekV3. Implements both single-device + and distributed SPMD logic for token routing and expert computation. """ def __init__(self, config: DictConfig): @@ -214,55 +241,13 @@ def __init__(self, config: DictConfig): self.K = config.num_experts_per_tok self.D = config.hidden_size self.I = config.moe_intermediate_size - - # Add a flag to control kernel usage, defaulting to True for TPU performance - self.use_gmm_kernel_for_moe = config.get("use_gmm_kernel_for_moe", True) + self.act_fn = ACT2FN[config.hidden_act] self.gate = DeepseekV3TopkRouter(config) self.grouped = GroupedMoEWeights(self.E, self.D, self.I, dtype=torch.bfloat16) self.shared_experts = DeepseekV3MLP( config=config, intermediate_size=self.I * config.n_shared_experts ) - self.act_fn = ACT2FN[config.hidden_act] - - # The weight loading functions remain the same for checkpoint compatibility - @torch.no_grad() - def _pre_load_old_keys(self, state_dict, prefix: str): - has_old = any( - k.startswith(prefix + "experts.0.gate_proj.weight") for k in state_dict - ) - if not has_old: - return - E = self.E - Wg = torch.stack( - [state_dict[f"{prefix}experts.{e}.gate_proj.weight"].t() for e in range(E)], dim=0 - ) - Wu = torch.stack( - [state_dict[f"{prefix}experts.{e}.up_proj.weight"].t() for e in range(E)], dim=0 - ) - Wd = torch.stack( - [state_dict[f"{prefix}experts.{e}.down_proj.weight"].t() for e in range(E)], dim=0 - ) - Wg = Wg.to(self.grouped.W_gate.dtype) - Wu = Wu.to(self.grouped.W_up.dtype) - Wd = Wd.to(self.grouped.W_down.dtype) - self.grouped.W_gate.copy_(Wg.contiguous()) - self.grouped.W_up.copy_(Wu.contiguous()) - self.grouped.W_down.copy_(Wd.contiguous()) - - @torch.no_grad() - def _post_state_dict_old_keys(self, state_dict, prefix: str): - E = self.E - for e in range(E): - state_dict[f"{prefix}experts.{e}.gate_proj.weight"] = ( - self.grouped.W_gate[e].t().contiguous().to(FP32) - ) - state_dict[f"{prefix}experts.{e}.up_proj.weight"] = ( - self.grouped.W_up[e].t().contiguous().to(FP32) - ) - state_dict[f"{prefix}experts.{e}.down_proj.weight"] = ( - self.grouped.W_down[e].t().contiguous().to(FP32) - ) def _grouped_weights(self, dtype: torch.dtype): return ( @@ -271,70 +256,118 @@ def _grouped_weights(self, dtype: torch.dtype): self.grouped.W_down.to(dtype), ) - def _cpu_gmm( - self, x: torch.Tensor, W: torch.Tensor, group_sizes: torch.Tensor - ) -> torch.Tensor: - """CPU-friendly implementation of Grouped Matrix Multiply.""" - # Split the input tensor into a list of tensors, one for each expert group. - # We filter out groups of size 0 to avoid errors with torch.split. 
- non_zero_sizes = [size for size in group_sizes.tolist() if size > 0] - expert_inputs = torch.split(x, non_zero_sizes, dim=0) - - output_chunks = [] - input_idx = 0 - for e in range(self.E): - if group_sizes[e] > 0: - # Perform standard matrix multiplication for the current expert - output_chunks.append(torch.matmul(expert_inputs[input_idx], W[e])) - input_idx += 1 - - return torch.cat(output_chunks, dim=0) - @xp.trace_me("DeepseekV3MoE") def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: B, S, D = hidden_states.shape T = B * S x = hidden_states.view(T, self.D) - topk_idx, topk_w = self.gate(x) # [T, K], [T, K] + topk_idx, topk_w = self.gate(x) topk_w = topk_w.to(x.dtype) expert_ids = topk_idx.view(-1) - expert_one_hot = F.one_hot(expert_ids, num_classes=self.E) - group_sizes = expert_one_hot.sum(dim=0).to(torch.int32) - - _, sort_ix = torch.sort(expert_ids) - replicated_x = x.unsqueeze(1).expand(-1, self.K, -1).reshape(-1, self.D) - sorted_x = replicated_x[sort_ix] - - W_gate, W_up, W_down = self._grouped_weights(x.dtype) + mesh = xs.get_global_mesh() + if mesh is None: + # Single-device logic + _, global_sort_indices = torch.sort(expert_ids) + global_inverse_permute_indices = torch.argsort(global_sort_indices) + replicated_x = x.unsqueeze(1).expand(-1, self.K, -1).reshape(-1, self.D) + sorted_x = replicated_x[global_sort_indices] + + # Use a graph-traceable method to calculate group_sizes + group_sizes_zeros = torch.zeros( + self.E, dtype=torch.int32, device=expert_ids.device + ) + ones = torch.ones_like(expert_ids, dtype=torch.int32) + group_sizes = group_sizes_zeros.index_add(0, expert_ids, ones) - # Conditional logic to switch between GMM kernel and CPU fallback - if self.use_gmm_kernel_for_moe and GMM is not None: - gate_out = GMM.apply(sorted_x, W_gate, group_sizes) - up_out = GMM.apply(sorted_x, W_up, group_sizes) - act_out = self.act_fn(gate_out) * up_out - mlp_output_sorted = GMM.apply(act_out, W_down, group_sizes) else: - if GMM is None and self.use_gmm_kernel_for_moe: - logger.warning_once( - "GMM custom kernel not available. Falling back to CPU-friendly MoE implementation. " - "This will be slower. To disable this warning, set `use_gmm_kernel_for_moe: false` in your config." - ) + # Distributed SPMD logic + mesh_size = mesh.shape()["expert"] + expert_dim_groups = _get_dim_groups_from_mesh(mesh, "expert") + num_experts_per_device = self.E // mesh_size + + # 1. Global permutation to group tokens by destination device + replicated_x = x.unsqueeze(1).expand(-1, self.K, -1).reshape(-1, self.D) + destination_device = expert_ids // num_experts_per_device + global_permute_indices = torch.argsort(destination_device) + + # Use .clone() to create a barrier and stabilize the graph trace + permuted_x = replicated_x[global_permute_indices].clone() + permuted_expert_ids = expert_ids[global_permute_indices].clone() + global_inverse_permute_indices = torch.argsort(global_permute_indices) + + num_tokens_to_dispatch = T * self.K + tokens_per_device = num_tokens_to_dispatch // mesh_size + + # 2. 
All-to-all dispatch + reshaped_for_dispatch = permuted_x.view(mesh_size, tokens_per_device, self.D) + dispatched_x = xm.all_to_all( + xs.mark_sharding_with_gradients( + reshaped_for_dispatch, mesh, ("expert", None, None) + ), + 0, + 0, + mesh_size, + groups=expert_dim_groups, + ).reshape(num_tokens_to_dispatch, self.D) + + reshaped_expert_ids = permuted_expert_ids.view(mesh_size, tokens_per_device) + dispatched_expert_ids = xm.all_to_all( + xs.mark_sharding_with_gradients(reshaped_expert_ids, mesh, ("expert", None)), + 0, + 0, + mesh_size, + groups=expert_dim_groups, + ).reshape(num_tokens_to_dispatch) + + # 3. Local permutation for GMM + local_expert_ids = dispatched_expert_ids % num_experts_per_device + local_sort_indices = torch.argsort(local_expert_ids) + sorted_x = dispatched_x[local_sort_indices].clone() + local_inverse_sort_indices = torch.argsort(local_sort_indices) + + group_sizes_zeros = torch.zeros( + num_experts_per_device, dtype=torch.int32, device=local_expert_ids.device + ) + ones = torch.ones_like(local_expert_ids, dtype=torch.int32) + group_sizes = group_sizes_zeros.index_add(0, local_expert_ids, ones) - gate_out = self._cpu_gmm(sorted_x, W_gate, group_sizes) - up_out = self._cpu_gmm(sorted_x, W_up, group_sizes) - act_out = self.act_fn(gate_out) * up_out - mlp_output_sorted = self._cpu_gmm(act_out, W_down, group_sizes) + w_gate, w_up, w_down = self._grouped_weights(x.dtype) - mlp_output_unsorted = torch.empty_like(mlp_output_sorted) - mlp_output_unsorted[sort_ix] = mlp_output_sorted + # 4. Expert Computation (GMM) + gate_out = DeepseekMoEGMM.apply(sorted_x, w_gate, group_sizes, "gate_up") + up_out = DeepseekMoEGMM.apply(sorted_x, w_up, group_sizes, "gate_up") + act_out = self.act_fn(gate_out) * up_out + mlp_output_sorted = DeepseekMoEGMM.apply(act_out, w_down, group_sizes, "down") + if mesh is None: + mlp_output_unsorted = mlp_output_sorted[global_inverse_permute_indices] + else: + # 5. Local Un-Permutation + dispatched_output = mlp_output_sorted[local_inverse_sort_indices].clone() + + # 6. All-to-all return + reshaped_for_return = dispatched_output.view(mesh_size, tokens_per_device, self.D) + permuted_output = xm.all_to_all( + xs.mark_sharding_with_gradients( + reshaped_for_return, mesh, ("expert", None, None) + ), + 0, + 0, + mesh_size, + groups=expert_dim_groups, + ).reshape(num_tokens_to_dispatch, self.D) + + # 7. Global Un-Permutation + mlp_output_unsorted = permuted_output[global_inverse_permute_indices] + + # 8. 
Combine and weigh outputs weighted_output = ( mlp_output_unsorted.view(T, self.K, self.D) * topk_w.unsqueeze(-1) ).sum(dim=1) - out = weighted_output.view(B, S, D) + self.shared_experts(hidden_states) + out = weighted_output.view(B, S, self.D) + self.shared_experts(hidden_states) return out @@ -347,18 +380,14 @@ def __init__(self, config: DictConfig, layer_idx: int | None = None): self.attention_block = AttentionModule(config) self.layer_idx = layer_idx self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.attention_dropout = ( - config.attention_dropout - ) # this is not used in the current implementation + self.attention_dropout = config.attention_dropout self.num_heads = config.num_attention_heads self.rope_theta = config.rope_theta - ############# self.q_lora_rank = config.q_lora_rank self.qk_rope_head_dim = config.qk_rope_head_dim self.kv_lora_rank = config.kv_lora_rank self.v_head_dim = config.v_head_dim self.qk_nope_head_dim = config.qk_nope_head_dim - ############# self.qk_head_dim = config.qk_head_dim self.is_causal = True
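
Note on the new MoE routing path: the permutation bookkeeping in `DeepseekV3MoE.forward` (global sort by destination device, local sort by local expert id, `index_add`-based group sizes, and the two inverse `argsort` permutations) can be sanity-checked on CPU with plain PyTorch. The sketch below exercises only that bookkeeping: the all-to-all collectives and the GMM kernels are elided, and the helper name `check_routing_roundtrip` and its sizes are illustrative, not part of this change.

import torch


def check_routing_roundtrip(num_tokens=16, hidden=32, num_experts=8, top_k=2, mesh_size=4):
  # Fake router output: top_k expert ids per token, flattened token-major as in forward().
  x = torch.randn(num_tokens, hidden)
  topk_idx = torch.randint(0, num_experts, (num_tokens, top_k))
  expert_ids = topk_idx.view(-1)
  replicated_x = x.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, hidden)

  # 1. Global permutation: group token copies by the device that owns their expert.
  num_experts_per_device = num_experts // mesh_size
  destination_device = expert_ids // num_experts_per_device
  global_perm = torch.argsort(destination_device)
  permuted_x = replicated_x[global_perm]
  permuted_ids = expert_ids[global_perm]
  global_inverse = torch.argsort(global_perm)

  # 3. Local permutation by local expert id, plus graph-traceable group sizes.
  #    (The all-to-all dispatch between steps 1 and 3 is skipped in this sketch.)
  local_ids = permuted_ids % num_experts_per_device
  local_perm = torch.argsort(local_ids)
  sorted_x = permuted_x[local_perm]
  group_sizes = torch.zeros(num_experts_per_device, dtype=torch.int32).index_add(
    0, local_ids, torch.ones_like(local_ids, dtype=torch.int32)
  )
  assert int(group_sizes.sum()) == num_tokens * top_k

  # 5./7. Undo both permutations; with an identity "expert", the round trip must
  # reproduce the replicated inputs exactly.
  restored = sorted_x[torch.argsort(local_perm)][global_inverse]
  assert torch.equal(restored, replicated_x)


check_routing_roundtrip()

Using `argsort` of each forward permutation as its inverse (rather than the scatter-style `mlp_output_unsorted[sort_ix] = ...` used previously) matches what the forward pass now does and keeps the whole routing path expressible as gathers.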