From a6f1f5521d3206dc39791ebc549c47674133f056 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Sun, 15 Feb 2026 14:42:13 +0800 Subject: [PATCH 1/4] refactor: add name-based routing & slice mode for HybridMuon opt --- deepmd/pt/optimizer/hybrid_muon.py | 577 ++++++++++++++++------------ deepmd/pt/train/training.py | 6 +- deepmd/utils/argcheck.py | 40 +- source/tests/pt/test_hybrid_muon.py | 160 ++++++-- 4 files changed, 496 insertions(+), 287 deletions(-) diff --git a/deepmd/pt/optimizer/hybrid_muon.py b/deepmd/pt/optimizer/hybrid_muon.py index 1083b30107..ecd8a4e0c3 100644 --- a/deepmd/pt/optimizer/hybrid_muon.py +++ b/deepmd/pt/optimizer/hybrid_muon.py @@ -3,18 +3,29 @@ HybridMuon optimizer for DeePMD-kit PyTorch backend. HybridMuon is a hybrid optimizer that automatically combines Muon and Adam. -Routing is controlled by parameter dimensionality and ``muon_2d_only``: - -- 1D parameters (biases, norms): Adam (no weight decay). -- When ``muon_2d_only=True`` (default): - - 2D parameters: Muon if ``min(m, n) >= min_2d_dim``, else Adam fallback. - - >2D parameters: Adam. -- When ``muon_2d_only=False``: - - >=2D parameters use matrix-view routing: - Muon if ``min(m, n) >= min_2d_dim``, else Adam fallback. - -For matrix-view routing, any parameter with ndim >= 2 is reshaped as: -``(rows, cols) = (numel // shape[-1], shape[-1])``. +Routing is controlled by parameter dimensionality, parameter names, and +``muon_mode``: + +- Parameters whose final effective name segment contains ``bias`` + (case-insensitive), or starts with ``adam_`` (case-insensitive): Adam. +- Parameters whose final effective name segment starts with ``adamw_`` + (case-insensitive): Adam with decoupled weight decay (AdamW-style). + The final effective segment means the last non-numeric segment in the full + parameter path (split by ``"."``), so trailing ParameterList indices are + ignored. +- 1D parameters (biases, norms, scales): Adam (no weight decay). +- ``muon_mode="2d"``: + - Matrix parameters with effective rank 2 (after dropping singleton dims) + use Muon. + - Effective rank >2 parameters use Adam with decoupled weight decay fallback. +- ``muon_mode="flat"``: + - >=2D matrix parameters use flattened matrix-view routing: + ``(rows, cols) = (prod(effective_shape[:-1]), effective_shape[-1])``. +- ``muon_mode="slice"`` (default): + - Effective rank 2 matrix parameters: same as ``"2d"``. + - Effective rank >=3 matrix parameters: treat leading axes as batch and apply Muon + independently on each ``(..., m, n)`` slice (no cross-slice mixing). + - Routing shape is computed on effective shape (singleton dims removed). This is different from PyTorch's torch.optim.Muon, which ONLY supports 2D parameters and requires manual configuration of AdamW for 1D parameters. HybridMuon provides @@ -79,6 +90,7 @@ if TYPE_CHECKING: from collections.abc import ( + Callable, Iterable, ) @@ -295,59 +307,195 @@ def _newton_schulz_orth( return X -def should_fallback_to_adam_for_matrix( - p: torch.Tensor, - min_2d_dim: int, -) -> bool: +def _batched_newton_schulz_orth( + G: torch.Tensor, +) -> torch.Tensor: """ - Check if a parameter should fallback to Adam based on matrix-view dimensions. + Orthogonalize a batch of matrices via quintic Newton-Schulz iteration. Parameters ---------- - p : torch.Tensor - Parameter tensor with ndim >= 2. - min_2d_dim : int - Minimum min(m, n) threshold for Muon. Matrices with min(m, n) >= - min_2d_dim use Muon; those with min(m, n) < min_2d_dim use Adam. + G : torch.Tensor + Input tensor with shape (B, m, n), where B is batch size. Returns ------- - bool - True if min(m, n) < min_2d_dim, False otherwise. + torch.Tensor + Orthogonalized tensor in bfloat16 with shape (B, m, n). + """ + # === Step 1. Validate and prepare matrix orientation === + if G.ndim != 3: + raise ValueError( + "Batched Newton-Schulz expects a 3D tensor with shape (B, m, n)." + ) + + X = G.to(dtype=torch.bfloat16) + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.transpose(-2, -1) + + # === Step 2. Normalize each slice by Frobenius norm === + X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=EPS) + + # === Step 3. Batched Newton-Schulz iterations === + for _ in range(NS_STEPS): + A = torch.bmm(X, X.transpose(-2, -1)) + gram_update = torch.baddbmm(A, A, A, beta=NS_COEFF_B, alpha=NS_COEFF_C) + X = torch.baddbmm(X, gram_update, X, beta=NS_COEFF_A, alpha=1.0) + + # === Step 4. Restore original orientation === + if transposed: + X = X.transpose(-2, -1) + + return X - Raises - ------ - ValueError - If tensor has ndim < 2. + +def get_adam_route( + param_name: str | None, +) -> str: """ - # === Step 1. Validate === - if p.ndim < 2: - raise ValueError("Parameter must have ndim >= 2 for Muon suitability check.") + Determine the optimizer route for a parameter based on its name. - # === Step 2. Derive matrix shape consistent with Muon reshape === - # Flatten all leading axes into rows and keep the last axis as cols. - # This preserves the "input-channel axis" as the NS orthogonalization space - # for N-D linear weights (e.g., (..., C_out, C_in) -> (-1, C_in)). - m = int(p.numel() // p.shape[-1]) - n = int(p.shape[-1]) + Parameters + ---------- + param_name : str | None + Parameter name. If None, fallback behavior treats parameter as + matrix (Muon-eligible). - # === Step 3. Check if any dimension too small for Muon === - return min(m, n) < min_2d_dim + Returns + ------- + str + ``"muon"`` if this parameter is eligible as matrix weight by name, + ``"adam"`` for Adam path (no weight decay), + ``"adamw"`` for AdamW path (decoupled weight decay). + + Notes + ----- + Name-based routing rules (case-insensitive, applied to the final + effective name segment after stripping trailing numeric ParameterList + indices): + + 1. Contains ``"bias"`` -> ``"adam"`` (no weight decay). + 2. Starts with ``"adam_"`` -> ``"adam"`` (no weight decay). + Typical: norm scales, radial frequencies. + 3. Starts with ``"adamw_"`` -> ``"adamw"`` (decoupled weight decay). + Typical: LayerScale parameters. + 4. Otherwise -> ``"muon"`` (eligible for Muon). + """ + if param_name is None: + return "muon" + param_name_lower = param_name.lower() + name_segments = param_name_lower.split(".") + leaf_name_idx = len(name_segments) - 1 + while leaf_name_idx > 0 and name_segments[leaf_name_idx].isdigit(): + leaf_name_idx -= 1 + leaf_name = name_segments[leaf_name_idx] + if "bias" in leaf_name: + return "adam" + if leaf_name.startswith("adam_"): + return "adam" + if leaf_name.startswith("adamw_"): + return "adamw" + return "muon" + + +def get_effective_shape( + shape: torch.Size | tuple[int, ...], +) -> tuple[int, ...]: + """ + Remove singleton dimensions from a tensor shape for routing decisions. + + Parameters + ---------- + shape + Original tensor shape. + + Returns + ------- + tuple[int, ...] + Shape without dimensions equal to 1. + If all dims are 1, returns ``(1,)``. + """ + effective = tuple(int(dim) for dim in shape if int(dim) != 1) + if len(effective) == 0: + return (1,) + return effective + + +def get_matrix_view_shape( + effective_shape: tuple[int, ...], + muon_mode: str, +) -> tuple[int, int, int] | None: + """ + Derive Muon matrix-view shape from effective tensor shape. + + Parameters + ---------- + effective_shape + Shape with singleton dimensions removed. + muon_mode + One of {"2d", "flat", "slice"}. + + Returns + ------- + tuple[int, int, int] | None + ``(batch_size, rows, cols)`` when Muon is applicable, otherwise ``None``. + """ + if len(effective_shape) < 2: + return None + + if muon_mode == "2d": + if len(effective_shape) != 2: + return None + return (1, int(effective_shape[-2]), int(effective_shape[-1])) + if muon_mode == "flat": + rows = int(math.prod(effective_shape[:-1])) + cols = int(effective_shape[-1]) + return (1, rows, cols) + if muon_mode == "slice": + if len(effective_shape) == 2: + return (1, int(effective_shape[-2]), int(effective_shape[-1])) + batch_size = int(math.prod(effective_shape[:-2])) + rows = int(effective_shape[-2]) + cols = int(effective_shape[-1]) + return (batch_size, rows, cols) + raise ValueError( + f"Unsupported muon_mode '{muon_mode}'. Expected one of ['2d', 'flat', 'slice']." + ) class HybridMuonOptimizer(Optimizer): """ - HybridMuon optimizer with small-matrix Adam fallback and 1D Adam path. - - This optimizer applies different update rules based on parameter dimensionality - and ``muon_2d_only``: - - 1D parameters (biases, layer norms): standard Adam update. - - When ``muon_2d_only=True``: - - 2D parameters use Muon/Adam-fallback according to ``min_2d_dim``. - - >2D parameters use Adam. - - When ``muon_2d_only=False``: - - >=2D parameters use matrix-view Muon/Adam-fallback according to - ``min_2d_dim``. + HybridMuon optimizer with 1D Adam path and matrix Muon path. + + This optimizer applies different update rules based on parameter dimensionality, + parameter names, and ``muon_mode``: + - Parameters with final effective name segment containing ``bias`` + (case-insensitive), or starting with ``adam_`` (case-insensitive): + standard Adam update. + - Parameters with final effective name segment starting with ``adamw_`` + (case-insensitive): Adam with decoupled weight decay (AdamW-style). + - 1D parameters: standard Adam update. + - Parameters are routed by effective shape (singleton dimensions removed). + - ``muon_mode="2d"``: + - effective rank 2 parameters use Muon. + - effective rank >2 parameters use Adam. + - ``muon_mode="flat"``: + - effective rank >=2 parameters use flattened matrix-view Muon. + - ``muon_mode="slice"``: + - effective rank 2 parameters use Muon. + - effective rank >=3 parameters apply Muon independently on each trailing + ``(m, n)`` slice. + + Naming convention for explicit Adam routing: + - Parameters representing bias terms should include ``bias`` in their + final effective name segment (case-insensitive). + - Parameters that are not semantic bias but should still use Adam should + use an ``adam_`` prefix in their final effective name segment + (case-insensitive). + - Parameters that should use Adam with decoupled weight decay should use + an ``adamw_`` prefix in their final effective name segment + (case-insensitive). This hybrid approach is effective because Muon's orthogonalization is designed for weight matrices, while Adam is more suitable for biases and normalization params. @@ -387,24 +535,20 @@ class HybridMuonOptimizer(Optimizer): scale = sqrt(max(1.0, m/n)). Adam uses lr/lr_adjust. Default is 10.0 (Adam lr = lr/10). lr_adjust_coeff : float - Dual-purpose coefficient with default 0.2: - 1. For Muon (when lr_adjust <= 0): match-RMS scaling factor, - scale = lr_adjust_coeff * sqrt(max(m, n)). - 2. For matrix Adam fallback: learning rate multiplier, - adam_lr_matrix = adam_lr * min(lr_adjust_coeff, 0.1). - The min(., 0.1) cap ensures conservative updates for small matrices. - muon_2d_only : bool - If True, only 2D parameters use Muon (matching PyTorch's torch.optim.Muon). - Parameters with ndim > 2 use AdamW-style updates. - If False, all >=2D parameters are eligible for Muon via matrix-view routing. - Default is True. - min_2d_dim : int - Minimum min(m, n) threshold for Muon on eligible matrix-view parameters. - Eligible parameters with min(m, n) >= min_2d_dim use Muon; - those with min(m, n) < min_2d_dim use Adam fallback. - Must be >= 1. - Set to 1 to disable fallback. - Default is 1. + Coefficient with default 0.2 for match-RMS scaling when + ``lr_adjust <= 0``: + ``scale = lr_adjust_coeff * sqrt(max(m, n))``. + muon_mode : str + Muon routing mode with default ``"slice"``. + - ``"2d"``: only 2D parameters are Muon candidates. + - ``"flat"``: >=2D parameters use flattened matrix-view routing. + - ``"slice"``: >=3D parameters use per-slice Muon routing on last two dims. + named_parameters : iterable[tuple[str, torch.Tensor]] | None + Optional named parameter iterable used for name-based routing. + Parameters with final effective name segment containing ``bias`` + (case-insensitive), or starting with ``adam_`` (case-insensitive), + are forced to Adam (no weight decay). Parameters starting with + ``adamw_`` are forced to AdamW-style decoupled decay path. flash_muon : bool Enable triton-accelerated Newton-Schulz orthogonalization. Requires triton and CUDA. Falls back to PyTorch implementation @@ -429,13 +573,16 @@ def __init__( adam_betas: tuple[float, float] = (0.9, 0.95), lr_adjust: float = 10.0, lr_adjust_coeff: float = 0.2, - muon_2d_only: bool = True, - min_2d_dim: int = 1, + muon_mode: str = "slice", + named_parameters: Iterable[tuple[str, torch.Tensor]] | None = None, flash_muon: bool = True, ) -> None: - if min_2d_dim < 1: - raise ValueError("min_2d_dim must be >= 1.") + # === Step 1. Validate routing mode === + muon_mode = str(muon_mode).lower() + if muon_mode not in {"2d", "flat", "slice"}: + raise ValueError("muon_mode must be one of ['2d', 'flat', 'slice'].") + # === Step 2. Register optimizer defaults === defaults = { "lr": lr, "momentum": momentum, @@ -443,15 +590,21 @@ def __init__( "adam_betas": adam_betas, "lr_adjust": lr_adjust, "lr_adjust_coeff": lr_adjust_coeff, - "muon_2d_only": muon_2d_only, - "min_2d_dim": min_2d_dim, + "muon_mode": muon_mode, } super().__init__(params, defaults) + + # === Step 3. Build parameter id -> name mapping === + self._param_name_map: dict[int, str] = {} + if named_parameters is not None: + for name, param in named_parameters: + self._param_name_map[id(param)] = str(name) + # Static parameter routing: built once on first step() call. self._routing_built = False self._routing: list[dict[str, Any]] = [] - # Flash-Muon: triton-accelerated Newton-Schulz + # === Step 4. Flash-Muon setup === self._use_flash = flash_muon and TRITON_AVAILABLE # Lazily allocated NS iteration buffers, keyed by (M, device) self._ns_buffers: dict[ @@ -489,13 +642,14 @@ def _get_ns_buffers( def _build_param_routing(self) -> None: """ - Classify parameters into Muon and Adam routes (static routing). + Classify parameters into Muon, Adam, and AdamW routes (static routing). Routing logic: - - 1D parameters → Adam path - - >2D parameters (when muon_2d_only=True) → Adam path - - >=2D parameters with min(m, n) < min_2d_dim → Adam fallback path - - remaining >=2D parameters → Muon path + - name-based ``adam_`` prefix or contains ``bias`` → Adam (no decay) + - name-based ``adamw_`` prefix → AdamW (decoupled weight decay) + - effective shape rank <2 → Adam (no decay) + - non-matrix effective shape for current muon_mode → AdamW (decoupled) + - remaining eligible matrix params → Muon path """ if self._routing_built: return @@ -503,49 +657,52 @@ def _build_param_routing(self) -> None: self._routing = [] for group in self.param_groups: muon_params: list[dict[str, Any]] = [] - adam_1d: list[dict[str, Any]] = [] - adam_matrix: list[dict[str, Any]] = [] - adam_nd: list[dict[str, Any]] = [] + adam_no_decay: list[dict[str, Any]] = [] + adam_decay: list[dict[str, Any]] = [] - min_2d_dim = group["min_2d_dim"] - muon_2d_only = group["muon_2d_only"] + muon_mode = group["muon_mode"] for p in group["params"]: - # === Step 1. 1D parameters → Adam === - if p.ndim < 2: - adam_1d.append({"param": p}) + param_name = self._param_name_map.get(id(p)) + + # === Step 1. Name-based explicit route === + route = get_adam_route(param_name) + if route == "adam": + adam_no_decay.append({"param": p, "name": param_name}) + continue + if route == "adamw": + adam_decay.append({"param": p, "name": param_name}) continue - # === Step 2. >2D parameters (when muon_2d_only=True) → Adam === - if muon_2d_only and p.ndim > 2: - adam_nd.append({"param": p}) + # === Step 2. Effective <2D parameters → Adam === + effective_shape = get_effective_shape(p.shape) + if len(effective_shape) < 2: + adam_no_decay.append({"param": p, "name": param_name}) continue - # === Step 3. Small matrix-view params → Adam fallback === - if should_fallback_to_adam_for_matrix(p, min_2d_dim=min_2d_dim): - adam_matrix.append( - { - "param": p, - "abs_floor": 1e-3 * math.sqrt(float(p.numel())), - } - ) + # === Step 3. Non-matrix effective shape in current mode → AdamW-style === + matrix_shape = get_matrix_view_shape(effective_shape, muon_mode) + if matrix_shape is None: + adam_decay.append({"param": p, "name": param_name}) continue - # === Step 4. >=2D (or 2D only when muon_2d_only=True) → Muon === + # === Step 4. Eligible matrix params → Muon === + batch_size, rows, cols = matrix_shape muon_params.append( { "param": p, - "rows": int(p.numel() // p.shape[-1]), - "cols": int(p.shape[-1]), + "name": param_name, + "batch_size": batch_size, + "rows": rows, + "cols": cols, } ) self._routing.append( { "muon_params": muon_params, - "adam_1d": adam_1d, - "adam_matrix": adam_matrix, - "adam_nd": adam_nd, + "adam_no_decay": adam_no_decay, + "adam_decay": adam_decay, } ) @@ -554,7 +711,7 @@ def _build_param_routing(self) -> None: @torch.no_grad() def step( self, - closure: callable | None = None, + closure: Callable[[], torch.Tensor] | None = None, ) -> torch.Tensor | None: """ Perform a single optimization step. @@ -586,15 +743,15 @@ def step( lr_adjust = group["lr_adjust"] lr_adjust_coeff = group["lr_adjust_coeff"] - # === Step 1. Adam update for 1D parameters (biases, norms, etc.) === + # === Step 1. Adam update for non-decay Adam path === # === Step 1.1. Collect gradients and initialize state === - adam_params: list[torch.Tensor] = [] - adam_grads_fp32: list[torch.Tensor] = [] - adam_exp_avgs: list[torch.Tensor] = [] - adam_exp_avg_sqs: list[torch.Tensor] = [] - adam_states: list[dict[str, Any]] = [] + adam_no_decay_params: list[torch.Tensor] = [] + adam_no_decay_grads_fp32: list[torch.Tensor] = [] + adam_no_decay_exp_avgs: list[torch.Tensor] = [] + adam_no_decay_exp_avg_sqs: list[torch.Tensor] = [] + adam_no_decay_states: list[dict[str, Any]] = [] - for entry in route["adam_1d"]: + for entry in route["adam_no_decay"]: p = entry["param"] grad = p.grad if grad is None: @@ -612,44 +769,44 @@ def step( state["beta1_pow"] *= adam_betas[0] state["beta2_pow"] *= adam_betas[1] - adam_params.append(p) - adam_grads_fp32.append(grad_fp32) - adam_exp_avgs.append(state["exp_avg"]) - adam_exp_avg_sqs.append(state["exp_avg_sq"]) - adam_states.append(state) + adam_no_decay_params.append(p) + adam_no_decay_grads_fp32.append(grad_fp32) + adam_no_decay_exp_avgs.append(state["exp_avg"]) + adam_no_decay_exp_avg_sqs.append(state["exp_avg_sq"]) + adam_no_decay_states.append(state) - if adam_params: + if adam_no_decay_params: # === Step 1.2. Update exp_avg / exp_avg_sq === adam_lr = lr if lr_adjust <= 0 else lr / lr_adjust # exp_avg = beta1 * exp_avg + (1 - beta1) * grad # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2 - for ea, g in zip(adam_exp_avgs, adam_grads_fp32): + for ea, g in zip(adam_no_decay_exp_avgs, adam_no_decay_grads_fp32): ea.lerp_(g, 1 - adam_betas[0]) - grad_sq = [g * g for g in adam_grads_fp32] - for eas, gsq in zip(adam_exp_avg_sqs, grad_sq): + grad_sq = [g * g for g in adam_no_decay_grads_fp32] + for eas, gsq in zip(adam_no_decay_exp_avg_sqs, grad_sq): eas.lerp_(gsq, 1 - adam_betas[1]) # === Step 1.3. Bias correction and parameter update === - for i, p in enumerate(adam_params): - state = adam_states[i] + for i, p in enumerate(adam_no_decay_params): + state = adam_no_decay_states[i] bias_corr1 = 1 - state["beta1_pow"] bias_corr2 = 1 - state["beta2_pow"] step_size = adam_lr / bias_corr1 # delta = -step_size * m_hat / (sqrt(v_hat) + eps) - denom = (adam_exp_avg_sqs[i] / bias_corr2).sqrt().add_(EPS) - delta_fp32 = -step_size * (adam_exp_avgs[i] / denom) + denom = (adam_no_decay_exp_avg_sqs[i] / bias_corr2).sqrt().add_(EPS) + delta_fp32 = -step_size * (adam_no_decay_exp_avgs[i] / denom) p.add_(delta_fp32.to(p.dtype)) - # === Step 2. Adam update for >2D parameters (when muon_2d_only=True) === + # === Step 2. AdamW-style update for decay-enabled Adam path === # === Step 2.1. Collect gradients and initialize state === - adam_nd_params: list[torch.Tensor] = [] - adam_nd_grads_fp32: list[torch.Tensor] = [] - adam_nd_exp_avgs: list[torch.Tensor] = [] - adam_nd_exp_avg_sqs: list[torch.Tensor] = [] - adam_nd_states: list[dict[str, Any]] = [] + adam_decay_params: list[torch.Tensor] = [] + adam_decay_grads_fp32: list[torch.Tensor] = [] + adam_decay_exp_avgs: list[torch.Tensor] = [] + adam_decay_exp_avg_sqs: list[torch.Tensor] = [] + adam_decay_states: list[dict[str, Any]] = [] - for entry in route.get("adam_nd", []): + for entry in route.get("adam_decay", []): p = entry["param"] grad = p.grad if grad is None: @@ -667,118 +824,41 @@ def step( state["beta1_pow"] *= adam_betas[0] state["beta2_pow"] *= adam_betas[1] - adam_nd_params.append(p) - adam_nd_grads_fp32.append(grad_fp32) - adam_nd_exp_avgs.append(state["exp_avg"]) - adam_nd_exp_avg_sqs.append(state["exp_avg_sq"]) - adam_nd_states.append(state) + adam_decay_params.append(p) + adam_decay_grads_fp32.append(grad_fp32) + adam_decay_exp_avgs.append(state["exp_avg"]) + adam_decay_exp_avg_sqs.append(state["exp_avg_sq"]) + adam_decay_states.append(state) - if adam_nd_params: + if adam_decay_params: # === Step 2.2. Update exp_avg / exp_avg_sq === adam_lr = lr if lr_adjust <= 0 else lr / lr_adjust # AdamW decay for >=2D Adam path. if weight_decay > 0: - for p in adam_nd_params: + for p in adam_decay_params: p.mul_(1.0 - lr * weight_decay) # exp_avg = beta1 * exp_avg + (1 - beta1) * grad # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2 - for ea, g in zip(adam_nd_exp_avgs, adam_nd_grads_fp32): + for ea, g in zip(adam_decay_exp_avgs, adam_decay_grads_fp32): ea.lerp_(g, 1 - adam_betas[0]) - grad_sq = [g * g for g in adam_nd_grads_fp32] - for eas, gsq in zip(adam_nd_exp_avg_sqs, grad_sq): + grad_sq = [g * g for g in adam_decay_grads_fp32] + for eas, gsq in zip(adam_decay_exp_avg_sqs, grad_sq): eas.lerp_(gsq, 1 - adam_betas[1]) # === Step 2.3. Bias correction and parameter update === - for i, p in enumerate(adam_nd_params): - state = adam_nd_states[i] + for i, p in enumerate(adam_decay_params): + state = adam_decay_states[i] bias_corr1 = 1 - state["beta1_pow"] bias_corr2 = 1 - state["beta2_pow"] step_size = adam_lr / bias_corr1 # delta = -step_size * m_hat / (sqrt(v_hat) + eps) - denom = (adam_nd_exp_avg_sqs[i] / bias_corr2).sqrt().add_(EPS) - delta_fp32 = -step_size * (adam_nd_exp_avgs[i] / denom) + denom = (adam_decay_exp_avg_sqs[i] / bias_corr2).sqrt().add_(EPS) + delta_fp32 = -step_size * (adam_decay_exp_avgs[i] / denom) p.add_(delta_fp32.to(p.dtype)) - # === Step 3. Adam update for small matrix-view params (fallback path) === - # === Step 3.1. Collect gradients and initialize state === - adam_matrix_params: list[torch.Tensor] = [] - adam_matrix_grads_fp32: list[torch.Tensor] = [] - adam_matrix_exp_avgs: list[torch.Tensor] = [] - adam_matrix_exp_avg_sqs: list[torch.Tensor] = [] - adam_matrix_states: list[dict[str, Any]] = [] - adam_matrix_abs_floor: list[float] = [] - - for entry in route["adam_matrix"]: - p = entry["param"] - grad = p.grad - if grad is None: - continue - - grad_fp32 = grad.float() - - state = self.state[p] - if "exp_avg" not in state: - state["exp_avg"] = torch.zeros_like(p, dtype=torch.float32) - state["exp_avg_sq"] = torch.zeros_like(p, dtype=torch.float32) - state["beta1_pow"] = 1.0 - state["beta2_pow"] = 1.0 - - state["beta1_pow"] *= adam_betas[0] - state["beta2_pow"] *= adam_betas[1] - - adam_matrix_params.append(p) - adam_matrix_grads_fp32.append(grad_fp32) - adam_matrix_exp_avgs.append(state["exp_avg"]) - adam_matrix_exp_avg_sqs.append(state["exp_avg_sq"]) - adam_matrix_states.append(state) - adam_matrix_abs_floor.append(entry["abs_floor"]) - - if adam_matrix_params: - # === Step 3.2. Update exp_avg / exp_avg_sq with scaled lr === - adam_lr = lr if lr_adjust <= 0 else lr / lr_adjust - adam_lr_matrix = adam_lr * min(lr_adjust_coeff, 0.1) - # AdamW decay for matrix fallback path. - if weight_decay > 0: - for p in adam_matrix_params: - p.mul_(1.0 - lr * weight_decay) - - # exp_avg = beta1 * exp_avg + (1 - beta1) * grad - # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2 - for ea, g in zip(adam_matrix_exp_avgs, adam_matrix_grads_fp32): - ea.lerp_(g, 1 - adam_betas[0]) - grad_sq_m = [g * g for g in adam_matrix_grads_fp32] - for eas, gsq in zip(adam_matrix_exp_avg_sqs, grad_sq_m): - eas.lerp_(gsq, 1 - adam_betas[1]) - - # === Step 3.3. Compute unclipped deltas === - raw_deltas: list[torch.Tensor] = [] - for i in range(len(adam_matrix_params)): - state = adam_matrix_states[i] - bias_corr1 = 1 - state["beta1_pow"] - bias_corr2 = 1 - state["beta2_pow"] - step_size = adam_lr_matrix / bias_corr1 - denom = (adam_matrix_exp_avg_sqs[i] / bias_corr2).sqrt().add_(EPS) - raw_deltas.append(-step_size * (adam_matrix_exp_avgs[i] / denom)) - - # === Step 3.4. Clip updates by relative norm and apply === - max_rel_change = 0.05 - p_norms = torch.stack([p.norm() for p in adam_matrix_params]) - delta_norms = torch.stack([d.norm() for d in raw_deltas]) - floors = torch.tensor( - adam_matrix_abs_floor, - device=p_norms.device, - dtype=p_norms.dtype, - ) - max_delta = torch.maximum(max_rel_change * p_norms, floors) - scales_tensor = torch.clamp(max_delta / (delta_norms + 1e-12), max=1.0) - for i, (p, delta) in enumerate( - zip(adam_matrix_params, raw_deltas, strict=False) - ): - p.add_(delta.mul_(scales_tensor[i]).to(p.dtype)) - - # === Step 4. Muon update for >=2D parameters (weight matrices) === - # === Step 4.1. Collect gradients and initialize momentum === + # === Step 3. Muon update for matrix parameters === + # === Step 3.1. Collect gradients and initialize momentum === muon_params_for_decay: list[torch.Tensor] = [] muon_grads: list[torch.Tensor] = [] muon_momentum_buffers: list[torch.Tensor] = [] @@ -803,7 +883,7 @@ def step( muon_momentum_buffers.append(buf) active_entries.append((entry, grad)) - # === Step 4.2. Apply weight decay on Muon path === + # === Step 3.2. Apply weight decay on Muon path === if weight_decay > 0 and muon_params_for_decay: for p in muon_params_for_decay: p.mul_(1.0 - lr * weight_decay) @@ -811,7 +891,7 @@ def step( if not active_entries: continue - # === Step 4.3. Momentum update (Nesterov) === + # === Step 3.3. Momentum update (Nesterov) === # m_t = beta * m_{t-1} + (1 - beta) * g_t for buf, g in zip(muon_momentum_buffers, muon_grads): buf.lerp_(g, 1 - momentum) @@ -821,22 +901,28 @@ def step( for g, buf in zip(muon_grads, muon_momentum_buffers) ] - # === Step 4.4. Bucket by shape/device/dtype for batched NS === + # === Step 3.4. Bucket by (batch_size, rows, cols, device, dtype) === buckets: dict[ - tuple[int, int, torch.device, torch.dtype], + tuple[int, int, int, torch.device, torch.dtype], list[tuple[dict[str, Any], torch.Tensor]], ] = {} for idx, entry_info in enumerate(active_entries): entry, _ = entry_info p = entry["param"] - bucket_key = (entry["rows"], entry["cols"], p.device, p.dtype) + bucket_key = ( + entry["batch_size"], + entry["rows"], + entry["cols"], + p.device, + p.dtype, + ) if bucket_key not in buckets: buckets[bucket_key] = [] buckets[bucket_key].append((entry, muon_updates[idx])) - # === Step 4.5. Newton-Schulz orthogonalization and update === - for (rows, cols, _device, _), bucket_entries in buckets.items(): + # === Step 3.5. Newton-Schulz orthogonalization and update === + for (batch_size, rows, cols, _device, _), bucket_entries in buckets.items(): # scale = coeff * sqrt(max(m, n)) [match-RMS mode] # scale = sqrt(max(1, m/n)) [rectangular mode] if lr_adjust <= 0: @@ -845,11 +931,15 @@ def step( scale = max(1.0, rows / cols) ** 0.5 # Determine if flash path is usable for this bucket. + # Flash path is enabled only for single-matrix updates. # Only beneficial when min(rows, cols) >= FLASH_MIN_DIM; # for small matrices, triton launch overhead > compute savings. M = min(rows, cols) use_flash = ( - self._use_flash and _device.type == "cuda" and M >= FLASH_MIN_DIM + batch_size == 1 + and self._use_flash + and _device.type == "cuda" + and M >= FLASH_MIN_DIM ) if use_flash: buf1, buf2 = self._get_ns_buffers(M, _device) @@ -857,14 +947,19 @@ def step( # Process each entry individually with Newton-Schulz orth. # Compatible with sharding propagation under FSDP2. for entry, update_tensor in bucket_entries: - update_matrix = update_tensor.reshape(rows, cols) - if not update_matrix.is_contiguous(): - update_matrix = update_matrix.contiguous() - - if use_flash: - orth = _flash_newton_schulz_orth(update_matrix, buf1, buf2) + if batch_size > 1: + update_batch = update_tensor.reshape(batch_size, rows, cols) + if not update_batch.is_contiguous(): + update_batch = update_batch.contiguous() + orth = _batched_newton_schulz_orth(update_batch) else: - orth = _newton_schulz_orth(update_matrix) + update_matrix = update_tensor.reshape(rows, cols) + if not update_matrix.is_contiguous(): + update_matrix = update_matrix.contiguous() + if use_flash: + orth = _flash_newton_schulz_orth(update_matrix, buf1, buf2) + else: + orth = _newton_schulz_orth(update_matrix) orth.mul_(scale) delta = orth.reshape(entry["param"].shape) entry["param"].add_(delta, alpha=-lr) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index c13846cbc9..86df4ff648 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -815,9 +815,9 @@ def single_model_finetune( "momentum": float(self.opt_param["momentum"]), "lr_adjust": float(self.opt_param["lr_adjust"]), "lr_adjust_coeff": float(self.opt_param["lr_adjust_coeff"]), - "muon_2d_only": bool(self.opt_param["muon_2d_only"]), - "min_2d_dim": int(self.opt_param["min_2d_dim"]), - "flash_muon": bool(self.opt_param["flash_muon"]), + "muon_mode": str(self.opt_param.get("muon_mode", "slice")), + "named_parameters": tuple(self.wrapper.named_parameters()), + "flash_muon": bool(self.opt_param.get("flash_muon", True)), } else: raise ValueError(f"Not supported optimizer type '{self.opt_type}'") diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 2d20319888..777b5bf79a 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2921,7 +2921,19 @@ def optimizer_adamuon() -> list[Argument]: ] -@opt_args_plugin.register("HybridMuon", doc=doc_only_pt_supported) +@opt_args_plugin.register( + "HybridMuon", + doc=doc_only_pt_supported + + "HybridMuon optimizer (DeePMD-kit custom implementation). " + + "This is a Hybrid optimizer that automatically combines Muon and Adam. " + + "For matrix params: Muon update with Newton-Schulz based on selected muon_mode. " + + "For 1D params: Standard Adam. " + + "Name-based Adam routing is enabled: final effective parameter name segment containing 'bias' " + + "or starting with 'adam_' (case-insensitive) always uses Adam (no weight decay); " + + "segment starting with 'adamw_' (case-insensitive) uses AdamW-style decoupled decay. " + + "Trailing numeric ParameterList indices are ignored when deriving the effective segment. " + + "This is DIFFERENT from PyTorch's torch.optim.Muon which ONLY supports 2D parameters.", +) def optimizer_hybrid_muon() -> list[Argument]: return [ Argument( @@ -2978,26 +2990,16 @@ def optimizer_hybrid_muon() -> list[Argument]: + "Coefficient for match-RMS scaling. Only effective when lr_adjust <= 0.", ), Argument( - "muon_2d_only", - bool, - optional=True, - default=True, - doc=doc_only_pt_supported - + "If True, only 2D parameters use Muon (matching PyTorch's torch.optim.Muon). " - + "Parameters with ndim > 2 use Adam without weight decay. " - + "If False, all >=2D parameters use Muon.", - ), - Argument( - "min_2d_dim", - int, + "muon_mode", + str, optional=True, - default=1, - alias=["muon_min_2d_dim"], + default="slice", doc=doc_only_pt_supported - + "Minimum min(m, n) threshold for HybridMuon on 2D matrices. " - "Matrices with min(m, n) >= min_2d_dim use HybridMuon; " - "those with min(m, n) < min_2d_dim use Adam fallback. " - "Set to 1 to disable fallback.", + + "Muon routing mode. " + + "'2d': only effective-rank-2 params are eligible for Muon; effective rank >2 goes to AdamW-style decoupled decay path. " + + "'flat': effective-rank >=2 params are flattened to matrix-view (prod(shape[:-1]), shape[-1]) for Muon. " + + "'slice' (default): effective-rank >=3 params use per-slice Muon on the last two dimensions; no cross-slice mixing. " + + "Routing uses effective shape after removing singleton dimensions.", ), Argument( "flash_muon", diff --git a/source/tests/pt/test_hybrid_muon.py b/source/tests/pt/test_hybrid_muon.py index f28e014188..017012ac60 100644 --- a/source/tests/pt/test_hybrid_muon.py +++ b/source/tests/pt/test_hybrid_muon.py @@ -6,6 +6,7 @@ from deepmd.pt.optimizer.hybrid_muon import ( TRITON_AVAILABLE, HybridMuonOptimizer, + _batched_newton_schulz_orth, _newton_schulz_orth, ) from deepmd.pt.utils import ( @@ -75,6 +76,14 @@ def test_shape_and_dtype(self) -> None: self.assertEqual(X.shape, G.shape) self.assertEqual(X.dtype, torch.bfloat16) + def test_batched_shape_and_dtype(self) -> None: + """Test batched NS preserves shape and returns bf16.""" + torch.manual_seed(42) + G = torch.randn(3, 6, 4, dtype=torch.float32, device=self.device) + X = _batched_newton_schulz_orth(G) + self.assertEqual(X.shape, G.shape) + self.assertEqual(X.dtype, torch.bfloat16) + def test_invalid_input(self) -> None: """Test that 1D input raises error.""" G_1d = torch.randn(10, dtype=torch.float32, device=self.device) @@ -143,30 +152,6 @@ def test_muon_adam_separation(self) -> None: self.assertIn("exp_avg_sq", optimizer.state[model.bias]) self.assertNotIn("momentum_buffer", optimizer.state[model.bias]) - def test_muon_adam_fallback_small_2d(self) -> None: - """Test Adam fallback for small 2D matrices when min_2d_dim is set.""" - torch.manual_seed(42) - linear_small = torch.nn.Linear(10, 1, bias=False, device=self.device) - linear_large = torch.nn.Linear(10, 10, bias=False, device=self.device) - optimizer = HybridMuonOptimizer( - list(linear_small.parameters()) + list(linear_large.parameters()), - lr=0.02, - min_2d_dim=2, - ) - - x = torch.randn(4, 10, device=self.device) - loss = linear_small(x).sum() + linear_large(x).sum() - loss.backward() - optimizer.step() - - # Small 2D weight should use Adam fallback. - self.assertIn("exp_avg", optimizer.state[linear_small.weight]) - self.assertNotIn("momentum_buffer", optimizer.state[linear_small.weight]) - - # Large 2D weight should use Muon. - self.assertIn("momentum_buffer", optimizer.state[linear_large.weight]) - self.assertNotIn("exp_avg", optimizer.state[linear_large.weight]) - def test_lr_adjust_modes(self) -> None: """Test lr_adjust modes: match-RMS (<=0) vs rectangular (>0).""" torch.manual_seed(42) @@ -193,6 +178,133 @@ def test_lr_adjust_modes(self) -> None: "Different lr_adjust modes should produce different updates", ) + def test_slice_mode_uses_muon_for_3d_weight(self) -> None: + """Test muon_mode='slice' + name rules route params as expected.""" + torch.manual_seed(42) + + class ToySliceModule(torch.nn.Module): + def __init__(self, device: torch.device) -> None: + super().__init__() + self.weight = torch.nn.Parameter( + torch.randn(2, 6, 4, dtype=torch.float32, device=device) + ) + self.adam_scale = torch.nn.Parameter( + torch.ones(2, 6, dtype=torch.float32, device=device) + ) + self.adam_stack = torch.nn.ParameterList( + [ + torch.nn.Parameter( + torch.ones(2, 6, dtype=torch.float32, device=device) + ) + ] + ) + self.adamw_layer_scale = torch.nn.Parameter( + torch.ones(2, 6, dtype=torch.float32, device=device) + ) + # Contains "bias" (case-insensitive) but not prefix. + self.gateBiAsScale = torch.nn.Parameter( + torch.ones(2, 6, dtype=torch.float32, device=device) + ) + # Module name contains "bias", but parameter leaf is "weight". + self.bias_proj = torch.nn.Linear(4, 6, bias=False, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + y = torch.einsum("bi,foi->bfo", x, self.weight) + y = y * self.adam_scale.unsqueeze(0) + y = y * self.adam_stack[0].unsqueeze(0) + y = y * self.adamw_layer_scale.unsqueeze(0) + y = y * self.gateBiAsScale.unsqueeze(0) + y = y + self.bias_proj(x).unsqueeze(1) + return y.sum() + + model = ToySliceModule(self.device) + optimizer = HybridMuonOptimizer( + model.parameters(), + lr=0.02, + muon_mode="slice", + named_parameters=tuple(model.named_parameters()), + ) + + x = torch.randn(4, 4, device=self.device) + model(x).backward() + optimizer.step() + + # 3D weight → Muon (slice mode) + self.assertIn("momentum_buffer", optimizer.state[model.weight]) + self.assertNotIn("exp_avg", optimizer.state[model.weight]) + # adam_ prefix → Adam (no weight decay) + self.assertIn("exp_avg", optimizer.state[model.adam_scale]) + self.assertNotIn("momentum_buffer", optimizer.state[model.adam_scale]) + self.assertIn("exp_avg", optimizer.state[model.adam_stack[0]]) + self.assertNotIn("momentum_buffer", optimizer.state[model.adam_stack[0]]) + # adamw_ prefix → AdamW (decoupled weight decay) + self.assertIn("exp_avg", optimizer.state[model.adamw_layer_scale]) + self.assertNotIn("momentum_buffer", optimizer.state[model.adamw_layer_scale]) + # Contains "bias" (case-insensitive) → Adam + self.assertIn("exp_avg", optimizer.state[model.gateBiAsScale]) + self.assertNotIn("momentum_buffer", optimizer.state[model.gateBiAsScale]) + # Module name "bias_proj" but leaf is "weight" → Muon + self.assertIn("momentum_buffer", optimizer.state[model.bias_proj.weight]) + self.assertNotIn("exp_avg", optimizer.state[model.bias_proj.weight]) + + def test_2d_mode_routes_3d_weight_to_adam(self) -> None: + """Test muon_mode='2d' routes 3D matrix weights to Adam.""" + torch.manual_seed(42) + + class Toy2DModeModule(torch.nn.Module): + def __init__(self, device: torch.device) -> None: + super().__init__() + self.weight = torch.nn.Parameter( + torch.randn(2, 6, 4, dtype=torch.float32, device=device) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.einsum("bi,foi->bfo", x, self.weight).sum() + + model = Toy2DModeModule(self.device) + optimizer = HybridMuonOptimizer( + model.parameters(), + lr=0.02, + muon_mode="2d", + named_parameters=tuple(model.named_parameters()), + ) + + x = torch.randn(4, 4, device=self.device) + model(x).backward() + optimizer.step() + + self.assertIn("exp_avg", optimizer.state[model.weight]) + self.assertNotIn("momentum_buffer", optimizer.state[model.weight]) + + def test_2d_mode_singleton_3d_routes_to_muon(self) -> None: + """Test muon_mode='2d' treats singleton-expanded matrix as 2D.""" + torch.manual_seed(42) + + class ToySingleton2DModeModule(torch.nn.Module): + def __init__(self, device: torch.device) -> None: + super().__init__() + self.weight = torch.nn.Parameter( + torch.randn(1, 6, 4, dtype=torch.float32, device=device) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.einsum("bi,foi->bfo", x, self.weight).sum() + + model = ToySingleton2DModeModule(self.device) + optimizer = HybridMuonOptimizer( + model.parameters(), + lr=0.02, + muon_mode="2d", + named_parameters=tuple(model.named_parameters()), + ) + + x = torch.randn(4, 4, device=self.device) + model(x).backward() + optimizer.step() + + self.assertIn("momentum_buffer", optimizer.state[model.weight]) + self.assertNotIn("exp_avg", optimizer.state[model.weight]) + @unittest.skipIf(not BF16_SUPPORTED, "bf16 matmul not supported on this device") class TestHybridMuonOptimizerStateDict(unittest.TestCase): From cde4ef77b72bf89335ae010d55dcbdd5f0f0b99a Mon Sep 17 00:00:00 2001 From: OutisLi Date: Sat, 21 Feb 2026 17:18:27 +0800 Subject: [PATCH 2/4] feat: add Magma-lite damping for Muon path; fix AdamW decay lr - Implement block-wise momentum-gradient alignment with EMA smoothing and soft scaling [0.1, 1.0] on Muon updates (magma_muon option) - Fix AdamW weight decay to use adam_lr instead of base lr - Wire magma_muon through training config and argcheck - Clean up redundant optimizer tests --- deepmd/pt/optimizer/hybrid_muon.py | 305 +++++++++++++++++++++++++++- deepmd/pt/train/training.py | 1 + deepmd/utils/argcheck.py | 11 + source/tests/pt/test_hybrid_muon.py | 152 ++++++++------ 4 files changed, 400 insertions(+), 69 deletions(-) diff --git a/deepmd/pt/optimizer/hybrid_muon.py b/deepmd/pt/optimizer/hybrid_muon.py index ecd8a4e0c3..75112658fd 100644 --- a/deepmd/pt/optimizer/hybrid_muon.py +++ b/deepmd/pt/optimizer/hybrid_muon.py @@ -71,6 +71,14 @@ https://github.com/MoonshotAI/Moonlight .. [4] Flash-Muon: Triton-accelerated symmetric matmul for Newton-Schulz. https://github.com/lintianyang/flash-muon (MIT License, Tianyang Lin) +.. [5] Magma: Momentum-Aligned Gradient Masking for Stable Optimizer Updates. + arXiv:2602.15322, 2025. + https://arxiv.org/abs/2602.15322 + Implements block-wise momentum-gradient alignment scoring with EMA smoothing + and soft scaling for improved stability under heavy-tailed gradient noise. + HybridMuon uses a stabilized variant (Magma-lite) with sigmoid range stretching + and continuous soft scaling [0.1, 1.0] instead of Bernoulli masking, optimized + for MLIP force-field training. """ from __future__ import ( @@ -122,6 +130,13 @@ # Below this threshold, triton kernel launch overhead dominates over compute, # and cuBLAS (via torch.mm/addmm) is faster for small matrices. FLASH_MIN_DIM: int = 1024 +# Magma-lite constants (Muon path update damping only) +MAGMA_TAU: float = 2.0 +MAGMA_EMA_DECAY: float = 0.9 +MAGMA_MIN_SCALE: float = 0.1 +MAGMA_EPS: float = 1e-12 +MAGMA_SIGMOID_MIN: float = 1.0 / (1.0 + math.exp(1.0 / MAGMA_TAU)) +MAGMA_SIGMOID_MAX: float = 1.0 / (1.0 + math.exp(-1.0 / MAGMA_TAU)) # ============================================================================ @@ -554,6 +569,11 @@ class HybridMuonOptimizer(Optimizer): Requires triton and CUDA. Falls back to PyTorch implementation when triton is unavailable or running on CPU. Default is True. + magma_muon : bool + Enable Magma-lite damping on Muon updates with default False. + This computes momentum-gradient cosine alignment per Muon block, + applies EMA smoothing, and rescales Muon updates in [0.1, 1.0]. + Adam/AdamW paths are unchanged. Examples -------- @@ -576,6 +596,7 @@ def __init__( muon_mode: str = "slice", named_parameters: Iterable[tuple[str, torch.Tensor]] | None = None, flash_muon: bool = True, + magma_muon: bool = False, ) -> None: # === Step 1. Validate routing mode === muon_mode = str(muon_mode).lower() @@ -591,6 +612,7 @@ def __init__( "lr_adjust": lr_adjust, "lr_adjust_coeff": lr_adjust_coeff, "muon_mode": muon_mode, + "magma_muon": bool(magma_muon), } super().__init__(params, defaults) @@ -612,6 +634,226 @@ def __init__( tuple[torch.Tensor, torch.Tensor], ] = {} + def _compute_magma_scale( + self, + param: torch.Tensor, + grad: torch.Tensor, + momentum_buffer: torch.Tensor, + batch_size: int, + rows: int, + cols: int, + ) -> torch.Tensor: + """ + Compute Magma-lite Muon damping scales from momentum-gradient alignment. + + Implements a stabilized version of Magma (Momentum-Aligned Gradient Masking) + adapted for MLIP force-field training. Computes block-wise alignment scores + between Muon momentum and current gradients, applies EMA smoothing, and + rescales Muon updates to improve stability under heavy-tailed gradient noise. + + Notes + ----- + For each Muon block b: + + 1. Compute cosine similarity between momentum and gradient: + + cos(b) = <μ_t^(b), g_t^(b)> / (||μ_t^(b)|| * ||g_t^(b)||) + + 2. Apply sigmoid with range stretching to [0, 1]: + + s_raw^(b) = (sigmoid(cos(b) / τ) - s_min) / (s_max - s_min) + + where τ=2.0, s_min=sigmoid(-1/τ), s_max=sigmoid(1/τ). + This stretches the narrow sigmoid range [0.38, 0.62] to [0, 1]. + + 3. Apply EMA smoothing: + + s̃_t^(b) = a * s̃_{t-1}^(b) + (1-a) * s_raw^(b) + + where a=0.9 (MAGMA_EMA_DECAY). + + 4. Map to damping scale in [s_min_scale, 1.0]: + + scale^(b) = s_min_scale + (1 - s_min_scale) * s̃_t^(b) + + where s_min_scale=0.1 (MAGMA_MIN_SCALE). + + 5. Apply damping to Muon update: + + Δ̃^(b) = scale^(b) * Δ^(b) (soft scaling, no Bernoulli masking) + + Key differences from the original Magma paper: + + - Sigmoid range stretching: Paper uses raw sigmoid with narrow range [0.38, 0.62]. + We stretch to [0, 1] for better discrimination between aligned/misaligned blocks. + - Soft scaling: Paper uses Bernoulli masking (50% skip probability). + We use continuous soft scaling [0.1, 1.0] for stability in MLIP training. + - Minimum scale: Paper allows scale=0 (complete skip). + We enforce scale >= 0.1 to guarantee minimum learning rate. + + Parameters + ---------- + param : torch.Tensor + Parameter updated by Muon. + grad : torch.Tensor + Current gradient tensor with shape compatible with ``(batch_size, rows, cols)``. + momentum_buffer : torch.Tensor + Muon momentum buffer (updated m_t) with same shape as ``grad``. + batch_size : int + Number of Muon blocks (1 for 2d/flat mode, >1 for slice mode). + rows : int + Matrix row count per block. + cols : int + Matrix column count per block. + + Returns + ------- + torch.Tensor + Damping scales with shape (batch_size,) in [MAGMA_MIN_SCALE, 1.0]. + """ + # === Step 1. Restore or initialize EMA score state === + state = self.state[param] + magma_score = state.get("magma_score") + if ( + magma_score is None + or magma_score.ndim != 1 + or magma_score.numel() != batch_size + or magma_score.device != param.device + ): + magma_score = torch.full( + (batch_size,), + 0.5, + dtype=torch.float32, + device=param.device, + ) + else: + magma_score = magma_score.to(dtype=torch.float32, device=param.device) + + # === Step 2. Build matrix-view for block-wise cosine === + grad_view = grad.reshape(batch_size, rows, cols).reshape(batch_size, -1) + momentum_view = momentum_buffer.reshape(batch_size, rows, cols).reshape( + batch_size, -1 + ) + grad_view = grad_view.to(dtype=torch.float32) + momentum_view = momentum_view.to(dtype=torch.float32) + + # === Step 3. Compute cosine alignment with numerical protection === + dot = (momentum_view * grad_view).sum(dim=1) + denom = (momentum_view.norm(dim=1) * grad_view.norm(dim=1)).clamp(min=MAGMA_EPS) + cosine = (dot / denom).clamp(min=-1.0, max=1.0) + + # === Step 4. Sigmoid mapping + range stretching to [0, 1] === + raw_sigmoid = torch.sigmoid(cosine / MAGMA_TAU) + raw_score = (raw_sigmoid - MAGMA_SIGMOID_MIN) / ( + MAGMA_SIGMOID_MAX - MAGMA_SIGMOID_MIN + ) + raw_score = raw_score.clamp(min=0.0, max=1.0) + + # === Step 5. Update EMA score and convert to damping scale === + magma_score = ( + MAGMA_EMA_DECAY * magma_score + (1.0 - MAGMA_EMA_DECAY) * raw_score + ) + state["magma_score"] = magma_score + return MAGMA_MIN_SCALE + (1.0 - MAGMA_MIN_SCALE) * magma_score + + def _compute_magma_scales_for_bucket( + self, + bucket_entries: list[ + tuple[dict[str, Any], torch.Tensor, torch.Tensor, torch.Tensor] + ], + batch_size: int, + rows: int, + cols: int, + ) -> list[torch.Tensor]: + """ + Compute Magma-lite damping scales for one Muon bucket in a batched way. + + Parameters + ---------- + bucket_entries : list[tuple[dict[str, Any], torch.Tensor, torch.Tensor, torch.Tensor]] + Bucket entries as ``(entry, update_tensor, grad, momentum_buffer)``. + batch_size : int + Number of Muon blocks per parameter in this bucket. + rows : int + Matrix row count for this bucket. + cols : int + Matrix column count for this bucket. + + Returns + ------- + list[torch.Tensor] + Magma scales for each bucket entry. Each tensor has shape (batch_size,). + """ + # === Step 0. Fast path for single-entry bucket === + if len(bucket_entries) == 1: + entry, _update_tensor, grad, momentum_buffer = bucket_entries[0] + return [ + self._compute_magma_scale( + param=entry["param"], + grad=grad, + momentum_buffer=momentum_buffer, + batch_size=batch_size, + rows=rows, + cols=cols, + ) + ] + + # === Step 1. Build batched matrix views === + grad_views: list[torch.Tensor] = [] + momentum_views: list[torch.Tensor] = [] + for _, _, grad, momentum_buffer in bucket_entries: + grad_view = grad.reshape(batch_size, rows, cols).reshape(batch_size, -1) + momentum_view = momentum_buffer.reshape(batch_size, rows, cols).reshape( + batch_size, -1 + ) + grad_views.append(grad_view.to(dtype=torch.float32)) + momentum_views.append(momentum_view.to(dtype=torch.float32)) + + grad_batch = torch.stack(grad_views, dim=0) + momentum_batch = torch.stack(momentum_views, dim=0) + + # === Step 2. Compute cosine alignment for all entries === + dot = (momentum_batch * grad_batch).sum(dim=2) + denom = (momentum_batch.norm(dim=2) * grad_batch.norm(dim=2)).clamp( + min=MAGMA_EPS + ) + cosine = (dot / denom).clamp(min=-1.0, max=1.0) + raw_sigmoid = torch.sigmoid(cosine / MAGMA_TAU) + raw_scores = (raw_sigmoid - MAGMA_SIGMOID_MIN) / ( + MAGMA_SIGMOID_MAX - MAGMA_SIGMOID_MIN + ) + raw_scores = raw_scores.clamp(min=0.0, max=1.0) + + # === Step 3. Update per-parameter EMA score state === + scales: list[torch.Tensor] = [] + for idx, (entry, _, _, _) in enumerate(bucket_entries): + param = entry["param"] + state = self.state[param] + magma_score = state.get("magma_score") + if ( + magma_score is None + or magma_score.ndim != 1 + or magma_score.numel() != batch_size + or magma_score.device != param.device + ): + magma_score = torch.full( + (batch_size,), + 0.5, + dtype=torch.float32, + device=param.device, + ) + state["magma_score"] = magma_score + elif magma_score.dtype != torch.float32: + magma_score = magma_score.to(dtype=torch.float32, device=param.device) + state["magma_score"] = magma_score + + magma_score.mul_(MAGMA_EMA_DECAY).add_( + raw_scores[idx], alpha=(1.0 - MAGMA_EMA_DECAY) + ) + scales.append(MAGMA_MIN_SCALE + (1.0 - MAGMA_MIN_SCALE) * magma_score) + + return scales + def _get_ns_buffers( self, M: int, @@ -742,6 +984,7 @@ def step( adam_betas = group["adam_betas"] lr_adjust = group["lr_adjust"] lr_adjust_coeff = group["lr_adjust_coeff"] + magma_muon = bool(group.get("magma_muon", False)) # === Step 1. Adam update for non-decay Adam path === # === Step 1.1. Collect gradients and initialize state === @@ -836,7 +1079,7 @@ def step( # AdamW decay for >=2D Adam path. if weight_decay > 0: for p in adam_decay_params: - p.mul_(1.0 - lr * weight_decay) + p.mul_(1.0 - adam_lr * weight_decay) # exp_avg = beta1 * exp_avg + (1 - beta1) * grad # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2 @@ -904,7 +1147,7 @@ def step( # === Step 3.4. Bucket by (batch_size, rows, cols, device, dtype) === buckets: dict[ tuple[int, int, int, torch.device, torch.dtype], - list[tuple[dict[str, Any], torch.Tensor]], + list[tuple[dict[str, Any], torch.Tensor, torch.Tensor, torch.Tensor]], ] = {} for idx, entry_info in enumerate(active_entries): @@ -919,7 +1162,14 @@ def step( ) if bucket_key not in buckets: buckets[bucket_key] = [] - buckets[bucket_key].append((entry, muon_updates[idx])) + buckets[bucket_key].append( + ( + entry, + muon_updates[idx], + muon_grads[idx], + muon_momentum_buffers[idx], + ) + ) # === Step 3.5. Newton-Schulz orthogonalization and update === for (batch_size, rows, cols, _device, _), bucket_entries in buckets.items(): @@ -944,24 +1194,57 @@ def step( if use_flash: buf1, buf2 = self._get_ns_buffers(M, _device) + if magma_muon: + bucket_magma_scales = self._compute_magma_scales_for_bucket( + bucket_entries=bucket_entries, + batch_size=batch_size, + rows=rows, + cols=cols, + ) + else: + bucket_magma_scales = [None] * len(bucket_entries) + # Process each entry individually with Newton-Schulz orth. # Compatible with sharding propagation under FSDP2. - for entry, update_tensor in bucket_entries: + for (entry, update_tensor, _grad, _buffer), magma_scale in zip( + bucket_entries, bucket_magma_scales, strict=True + ): if batch_size > 1: - update_batch = update_tensor.reshape(batch_size, rows, cols) - if not update_batch.is_contiguous(): - update_batch = update_batch.contiguous() + if update_tensor.is_contiguous(): + update_batch = update_tensor.view(batch_size, rows, cols) + else: + update_batch = update_tensor.reshape( + batch_size, rows, cols + ).contiguous() orth = _batched_newton_schulz_orth(update_batch) else: - update_matrix = update_tensor.reshape(rows, cols) - if not update_matrix.is_contiguous(): - update_matrix = update_matrix.contiguous() + if update_tensor.is_contiguous(): + update_matrix = update_tensor.view(rows, cols) + else: + update_matrix = update_tensor.reshape( + rows, cols + ).contiguous() if use_flash: orth = _flash_newton_schulz_orth(update_matrix, buf1, buf2) else: orth = _newton_schulz_orth(update_matrix) orth.mul_(scale) - delta = orth.reshape(entry["param"].shape) + if batch_size > 1: + orth_view = orth.reshape(batch_size, rows, cols) + if magma_scale is not None: + orth_view.mul_( + magma_scale.view(batch_size, 1, 1).to( + dtype=orth.dtype, + device=orth.device, + ) + ) + delta = orth_view.reshape(entry["param"].shape) + else: + if magma_scale is not None: + orth.mul_( + magma_scale[0].to(dtype=orth.dtype, device=orth.device) + ) + delta = orth.reshape(entry["param"].shape) entry["param"].add_(delta, alpha=-lr) return loss diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 86df4ff648..4a0734b255 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -818,6 +818,7 @@ def single_model_finetune( "muon_mode": str(self.opt_param.get("muon_mode", "slice")), "named_parameters": tuple(self.wrapper.named_parameters()), "flash_muon": bool(self.opt_param.get("flash_muon", True)), + "magma_muon": bool(self.opt_param.get("magma_muon", False)), } else: raise ValueError(f"Not supported optimizer type '{self.opt_type}'") diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 777b5bf79a..7f6e4bfc5d 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3011,6 +3011,17 @@ def optimizer_hybrid_muon() -> list[Argument]: "Requires triton and CUDA. Falls back to PyTorch implementation " "when triton is unavailable or running on CPU.", ), + Argument( + "magma_muon", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + + "Enable Magma-lite damping on the Muon route only. " + "When enabled, HybridMuon computes momentum-gradient alignment " + "per Muon block, applies EMA smoothing, and rescales Muon updates " + "to improve stability. Adam/AdamW routes are unchanged.", + ), ] diff --git a/source/tests/pt/test_hybrid_muon.py b/source/tests/pt/test_hybrid_muon.py index 017012ac60..85321c305d 100644 --- a/source/tests/pt/test_hybrid_muon.py +++ b/source/tests/pt/test_hybrid_muon.py @@ -4,9 +4,9 @@ import torch from deepmd.pt.optimizer.hybrid_muon import ( + MAGMA_MIN_SCALE, TRITON_AVAILABLE, HybridMuonOptimizer, - _batched_newton_schulz_orth, _newton_schulz_orth, ) from deepmd.pt.utils import ( @@ -67,23 +67,6 @@ def test_orthogonalization(self) -> None: off_diag_norm, 1.5, f"Off-diagonal norm too large: {off_diag_norm}" ) - def test_shape_and_dtype(self) -> None: - """Test that output preserves shape and returns bf16.""" - torch.manual_seed(42) - for shape in [(4, 4), (6, 4)]: - G = torch.randn(*shape, dtype=torch.float32, device=self.device) - X = _newton_schulz_orth(G) - self.assertEqual(X.shape, G.shape) - self.assertEqual(X.dtype, torch.bfloat16) - - def test_batched_shape_and_dtype(self) -> None: - """Test batched NS preserves shape and returns bf16.""" - torch.manual_seed(42) - G = torch.randn(3, 6, 4, dtype=torch.float32, device=self.device) - X = _batched_newton_schulz_orth(G) - self.assertEqual(X.shape, G.shape) - self.assertEqual(X.dtype, torch.bfloat16) - def test_invalid_input(self) -> None: """Test that 1D input raises error.""" G_1d = torch.randn(10, dtype=torch.float32, device=self.device) @@ -98,27 +81,6 @@ class TestHybridMuonOptimizer(unittest.TestCase): def setUp(self) -> None: self.device = env.DEVICE - def test_step(self) -> None: - """Test basic optimizer step changes parameters.""" - torch.manual_seed(42) - model = torch.nn.Sequential( - torch.nn.Linear(10, 20, device=self.device), - torch.nn.ReLU(), - torch.nn.Linear(20, 5, device=self.device), - ) - optimizer = HybridMuonOptimizer(model.parameters(), lr=0.02) - - x = torch.randn(4, 10, device=self.device) - model(x).sum().backward() - - initial_params = [p.clone() for p in model.parameters()] - optimizer.step() - - for i, (p, init_p) in enumerate( - zip(model.parameters(), initial_params, strict=True) - ): - self.assertFalse(torch.allclose(p, init_p), f"Parameter {i} did not change") - def test_weight_decay(self) -> None: """Test weight decay reduces parameter norm.""" torch.manual_seed(42) @@ -305,6 +267,99 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.assertIn("momentum_buffer", optimizer.state[model.weight]) self.assertNotIn("exp_avg", optimizer.state[model.weight]) + def test_magma_muon_slice_state_and_range(self) -> None: + """Test magma_muon creates bounded per-slice scores on Muon path.""" + torch.manual_seed(42) + + class ToyMagmaSlice(torch.nn.Module): + def __init__(self, device: torch.device) -> None: + super().__init__() + self.weight = torch.nn.Parameter( + torch.randn(2, 6, 4, dtype=torch.float32, device=device) + ) + self.adam_scale = torch.nn.Parameter( + torch.ones(2, 6, dtype=torch.float32, device=device) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + y = torch.einsum("bi,foi->bfo", x, self.weight) + y = y * self.adam_scale.unsqueeze(0) + return y.sum() + + model = ToyMagmaSlice(self.device) + optimizer = HybridMuonOptimizer( + model.parameters(), + lr=0.02, + muon_mode="slice", + named_parameters=tuple(model.named_parameters()), + magma_muon=True, + ) + + x = torch.randn(4, 4, device=self.device) + optimizer.zero_grad() + model(x).backward() + optimizer.step() + + score = optimizer.state[model.weight]["magma_score"] + self.assertEqual(score.shape, (2,)) + self.assertTrue(torch.all(score >= 0.0)) + self.assertTrue(torch.all(score <= 1.0)) + scale = MAGMA_MIN_SCALE + (1.0 - MAGMA_MIN_SCALE) * score + self.assertTrue(torch.all(scale >= MAGMA_MIN_SCALE)) + self.assertTrue(torch.all(scale <= 1.0)) + self.assertNotIn("magma_score", optimizer.state[model.adam_scale]) + + def test_magma_muon_only_affects_muon_path(self) -> None: + """Test Magma damping changes Muon updates but keeps Adam path unchanged.""" + torch.manual_seed(42) + + class ToyMagmaMixed(torch.nn.Module): + def __init__(self, device: torch.device) -> None: + super().__init__() + self.weight = torch.nn.Parameter( + torch.randn(2, 6, 4, dtype=torch.float32, device=device) + ) + self.adam_scale = torch.nn.Parameter( + torch.ones(2, 6, dtype=torch.float32, device=device) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + y = torch.einsum("bi,foi->bfo", x, self.weight) + y = y * self.adam_scale.unsqueeze(0) + return y.sum() + + model1 = ToyMagmaMixed(self.device) + model2 = ToyMagmaMixed(self.device) + model2.load_state_dict(model1.state_dict()) + + opt_off = HybridMuonOptimizer( + model1.parameters(), + lr=0.02, + muon_mode="slice", + named_parameters=tuple(model1.named_parameters()), + magma_muon=False, + ) + opt_on = HybridMuonOptimizer( + model2.parameters(), + lr=0.02, + muon_mode="slice", + named_parameters=tuple(model2.named_parameters()), + magma_muon=True, + ) + + x = torch.randn(4, 4, device=self.device) + opt_off.zero_grad() + model1(x).backward() + opt_off.step() + opt_on.zero_grad() + model2(x).backward() + opt_on.step() + + self.assertFalse(torch.allclose(model1.weight, model2.weight)) + self.assertTrue( + torch.allclose(model1.adam_scale, model2.adam_scale, atol=0.0, rtol=0.0) + ) + @unittest.skipIf(not BF16_SUPPORTED, "bf16 matmul not supported on this device") class TestHybridMuonOptimizerStateDict(unittest.TestCase): @@ -349,25 +404,6 @@ class TestFlashMuon(unittest.TestCase): def setUp(self) -> None: self.device = env.DEVICE - def test_flash_muon_false_runs(self) -> None: - """Test that flash_muon=False uses pure PyTorch path without error.""" - torch.manual_seed(42) - model = torch.nn.Linear(10, 20, device=self.device) - optimizer = HybridMuonOptimizer(model.parameters(), lr=0.02, flash_muon=False) - x = torch.randn(4, 10, device=self.device) - model(x).sum().backward() - optimizer.step() - # Should complete without error - - def test_flash_muon_true_runs(self) -> None: - """Test that flash_muon=True runs (falls back on CPU, uses triton on CUDA).""" - torch.manual_seed(42) - model = torch.nn.Linear(10, 20, device=self.device) - optimizer = HybridMuonOptimizer(model.parameters(), lr=0.02, flash_muon=True) - x = torch.randn(4, 10, device=self.device) - model(x).sum().backward() - optimizer.step() - def test_flash_vs_pytorch_consistency(self) -> None: """Test that flash and non-flash paths produce consistent results. From d29ee0c35d1c36baeae51432a0ebcd160439946e Mon Sep 17 00:00:00 2001 From: OutisLi Date: Wed, 11 Mar 2026 11:38:02 +0800 Subject: [PATCH 3/4] change default values --- deepmd/pt/optimizer/hybrid_muon.py | 10 +++++----- deepmd/utils/argcheck.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/deepmd/pt/optimizer/hybrid_muon.py b/deepmd/pt/optimizer/hybrid_muon.py index 75112658fd..871a2cf613 100644 --- a/deepmd/pt/optimizer/hybrid_muon.py +++ b/deepmd/pt/optimizer/hybrid_muon.py @@ -21,7 +21,7 @@ - ``muon_mode="flat"``: - >=2D matrix parameters use flattened matrix-view routing: ``(rows, cols) = (prod(effective_shape[:-1]), effective_shape[-1])``. -- ``muon_mode="slice"`` (default): +- ``muon_mode="slice"``: - Effective rank 2 matrix parameters: same as ``"2d"``. - Effective rank >=3 matrix parameters: treat leading axes as batch and apply Muon independently on each ``(..., m, n)`` slice (no cross-slice mixing). @@ -533,7 +533,7 @@ class HybridMuonOptimizer(Optimizer): params : iterable Iterable of parameters to optimize. lr : float - Learning rate with default 1e-3. + Learning rate. momentum : float Momentum coefficient for Muon with default 0.95. weight_decay : float @@ -577,7 +577,7 @@ class HybridMuonOptimizer(Optimizer): Examples -------- - >>> optimizer = HybridMuonOptimizer(model.parameters(), lr=1e-3) + >>> optimizer = HybridMuonOptimizer(model.parameters(), lr=5e-4) >>> for epoch in range(epochs): ... optimizer.zero_grad() ... loss.backward() @@ -587,11 +587,11 @@ class HybridMuonOptimizer(Optimizer): def __init__( self, params: Iterable[torch.Tensor] | Iterable[dict[str, Any]], - lr: float = 1e-3, + lr: float = 5e-4, momentum: float = 0.95, weight_decay: float = 0.001, adam_betas: tuple[float, float] = (0.9, 0.95), - lr_adjust: float = 10.0, + lr_adjust: float = 0.0, lr_adjust_coeff: float = 0.2, muon_mode: str = "slice", named_parameters: Iterable[tuple[str, torch.Tensor]] | None = None, diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 7f6e4bfc5d..de655cb06d 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2974,12 +2974,12 @@ def optimizer_hybrid_muon() -> list[Argument]: "lr_adjust", float, optional=True, - default=10.0, + default=0.0, doc=doc_only_pt_supported + "Learning rate adjustment mode for HybridMuon scaling and Adam learning rate. " "If lr_adjust <= 0: use match-RMS scaling (scale = coeff*sqrt(max(m,n))), Adam uses lr directly. " "If lr_adjust > 0: use rectangular correction (scale = sqrt(max(1, m/n))), Adam uses lr/lr_adjust. " - "Default is 10.0 (Adam lr = lr/10).", + "Default is 0.0 (match-RMS scaling).", ), Argument( "lr_adjust_coeff", From a2c52c6c9727448988cfbd3e9157c923db9b5cf6 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Wed, 11 Mar 2026 11:44:31 +0800 Subject: [PATCH 4/4] fixup --- deepmd/pt/optimizer/hybrid_muon.py | 26 ++++++++++++++------------ source/tests/pt/test_hybrid_muon.py | 2 +- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/deepmd/pt/optimizer/hybrid_muon.py b/deepmd/pt/optimizer/hybrid_muon.py index 871a2cf613..4f30cab89a 100644 --- a/deepmd/pt/optimizer/hybrid_muon.py +++ b/deepmd/pt/optimizer/hybrid_muon.py @@ -72,7 +72,7 @@ .. [4] Flash-Muon: Triton-accelerated symmetric matmul for Newton-Schulz. https://github.com/lintianyang/flash-muon (MIT License, Tianyang Lin) .. [5] Magma: Momentum-Aligned Gradient Masking for Stable Optimizer Updates. - arXiv:2602.15322, 2025. + arXiv:2602.15322, 2026. https://arxiv.org/abs/2602.15322 Implements block-wise momentum-gradient alignment scoring with EMA smoothing and soft scaling for improved stability under heavy-tailed gradient noise. @@ -340,9 +340,7 @@ def _batched_newton_schulz_orth( """ # === Step 1. Validate and prepare matrix orientation === if G.ndim != 3: - raise ValueError( - "Batched Newton-Schulz expects a 3D tensor with shape (B, m, n)." - ) + raise ValueError("Batched Newton-Schulz expects a 3D tensor (B, m, n).") X = G.to(dtype=torch.bfloat16) transposed = X.size(-2) > X.size(-1) @@ -474,9 +472,7 @@ def get_matrix_view_shape( rows = int(effective_shape[-2]) cols = int(effective_shape[-1]) return (batch_size, rows, cols) - raise ValueError( - f"Unsupported muon_mode '{muon_mode}'. Expected one of ['2d', 'flat', 'slice']." - ) + raise ValueError(f"Invalid muon_mode '{muon_mode}'. Use '2d', 'flat', or 'slice'.") class HybridMuonOptimizer(Optimizer): @@ -601,7 +597,9 @@ def __init__( # === Step 1. Validate routing mode === muon_mode = str(muon_mode).lower() if muon_mode not in {"2d", "flat", "slice"}: - raise ValueError("muon_mode must be one of ['2d', 'flat', 'slice'].") + raise ValueError( + f"Invalid muon_mode '{muon_mode}'. Use '2d', 'flat', or 'slice'." + ) # === Step 2. Register optimizer defaults === defaults = { @@ -1024,10 +1022,12 @@ def step( # exp_avg = beta1 * exp_avg + (1 - beta1) * grad # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2 - for ea, g in zip(adam_no_decay_exp_avgs, adam_no_decay_grads_fp32): + for ea, g in zip( + adam_no_decay_exp_avgs, adam_no_decay_grads_fp32, strict=True + ): ea.lerp_(g, 1 - adam_betas[0]) grad_sq = [g * g for g in adam_no_decay_grads_fp32] - for eas, gsq in zip(adam_no_decay_exp_avg_sqs, grad_sq): + for eas, gsq in zip(adam_no_decay_exp_avg_sqs, grad_sq, strict=True): eas.lerp_(gsq, 1 - adam_betas[1]) # === Step 1.3. Bias correction and parameter update === @@ -1083,10 +1083,12 @@ def step( # exp_avg = beta1 * exp_avg + (1 - beta1) * grad # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2 - for ea, g in zip(adam_decay_exp_avgs, adam_decay_grads_fp32): + for ea, g in zip( + adam_decay_exp_avgs, adam_decay_grads_fp32, strict=True + ): ea.lerp_(g, 1 - adam_betas[0]) grad_sq = [g * g for g in adam_decay_grads_fp32] - for eas, gsq in zip(adam_decay_exp_avg_sqs, grad_sq): + for eas, gsq in zip(adam_decay_exp_avg_sqs, grad_sq, strict=True): eas.lerp_(gsq, 1 - adam_betas[1]) # === Step 2.3. Bias correction and parameter update === diff --git a/source/tests/pt/test_hybrid_muon.py b/source/tests/pt/test_hybrid_muon.py index 85321c305d..a10a9dcf58 100644 --- a/source/tests/pt/test_hybrid_muon.py +++ b/source/tests/pt/test_hybrid_muon.py @@ -357,7 +357,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.assertFalse(torch.allclose(model1.weight, model2.weight)) self.assertTrue( - torch.allclose(model1.adam_scale, model2.adam_scale, atol=0.0, rtol=0.0) + torch.allclose(model1.adam_scale, model2.adam_scale, atol=1e-7, rtol=1e-7) )