From 7fab90c729abd46ac8658cc0e713fc5f7124b6f9 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Fri, 18 Jul 2025 15:42:30 +0900 Subject: [PATCH 01/22] Add cutlass-python-dsl executor. Starting with quack's softmax Signed-off-by: Masaki Kozuki --- thunder/executors/cutlass_dsl_ex.py | 147 ++++++++++++++++++++++++++++ 1 file changed, 147 insertions(+) create mode 100644 thunder/executors/cutlass_dsl_ex.py diff --git a/thunder/executors/cutlass_dsl_ex.py b/thunder/executors/cutlass_dsl_ex.py new file mode 100644 index 0000000000..c817eb886b --- /dev/null +++ b/thunder/executors/cutlass_dsl_ex.py @@ -0,0 +1,147 @@ +from __future__ import annotations +from importlib.metadata import version, PackageNotFoundError +from importlib.util import find_spec +import warnings +from typing import TYPE_CHECKING + +from looseversion import LooseVersion +import torch + +from thunder.core.transforms import get_grad, put_grad +from thunder.extend import register_executor, OperatorExecutor +from thunder.core.proxies import TensorProxy +import thunder.torch as ltorch + +if TYPE_CHECKING: + from thunder.core.dtypes import dtype as thunder_dtype + + +__all__ = [ + "cutlass_dsl_available", + "cutlass_dsl_version", + "cutlass_dsl_ex", +] + + +def cutlass_dsl_version() -> LooseVersion | None: + """Returns ``cutlass`` version if available, otherwise, :obj:`None`""" + + if not torch.cuda.is_available(): + return None + + if find_spec("cutlass") is None: + return None + + # First, check if it's cutlass>=4.0.0 which has the distribution name of nvidia-cutlass-dsl + # ref: https://pypi.org/project/nvidia-cutlass-dsl/ + cutlass_python_version: LooseVersion + nvidia_cutlass_dsl_version: str | None = None + nvidia_cutlass_version: str | None = None + try: + nvidia_cutlass_dsl_version = version("nvidia-cutlass-dsl") + except PackageNotFoundError: + try: + # Then check if it's <4 which has the name of nvidia-cutlass + # ref: https://pypi.org/project/nvidia-cutlass/ + nvidia_cutlass_version = version("nvidia-cutlass") + except PackageNotFoundError: + return None + else: + cutlass_python_version = LooseVersion(nvidia_cutlass_version) + else: + cutlass_python_version = LooseVersion(nvidia_cutlass_dsl_version) + + return cutlass_python_version + + +def required_cutlass_dsl_version() -> LooseVersion: + return LooseVersion("4.0.0") + + +def cutlass_dsl_available() -> bool: + ver = cutlass_dsl_version() + + if ver is None: + return False + + if ver < required_cutlass_dsl_version(): + msg = f"Available cutlass version is out of date. Thunder requires 4.0.0, but found {ver}" + warnings.warn(msg) + return False + + return True + + +cutlass_dsl_ex = OperatorExecutor("cutlass_dsl", version=cutlass_dsl_version()) +register_executor(cutlass_dsl_ex) + + +# Register [`quack`](https://github.com/Dao-AILab/quack) ops +if find_spec("quack") is not None: + from quack.softmax import _softmax_fwd, _softmax_backward + + def quack_softmax_impl(a: torch.Tensor) -> torch.Tensor: + return _softmax_fwd(a) + + def quack_softmax_meta(a: TensorProxy) -> TensorProxy: + return TensorProxy(like=a) + + quack_softmax = cutlass_dsl_ex.register_operator( + "cutlass_quack_softmax_forward", + meta=quack_softmax_meta, + fn=quack_softmax_impl, + ) + + def quack_softmax_backward(g: torch.Tensor, a: torch.Tensor) -> torch.Tensor: + return _softmax_backward(g, a) + + def quack_softmax_backward_meta(g: TensorProxy, a: TensorProxy) -> TensorProxy: + return TensorProxy(like=g) + + quack_softmax_backward = cutlass_dsl_ex.register_operator( + "cutlass_quack_softmax_backward", + meta=quack_softmax_backward_meta, + fn=quack_softmax_backward, + ) + + def quack_softmax_checker( + a: TensorProxy, + /, + dim: int, + *, + dtype: thunder_dtype | None = None, + ) -> bool: + last_dims = {-1, a.ndim - 1} + allowed_dtypes = {None, a.dtype} + return dim in last_dims and dtype in allowed_dtypes and torch.cuda.get_device_capability() in ((9, 0), (10, 0)) + + def quack_softmax_transform( + a: TensorProxy, + /, + dim: int, + *, + dtype: thunder_dtype | None = None, + ) -> TensorProxy: + return quack_softmax(a) + + def quack_softmax_grad( + a: TensorProxy, + /, + dim: int, + *, + dtype: thunder_dtype | None = None, + ) -> TensorProxy: + fwd = quack_softmax(a) + g = get_grad(fwd) + a_grad = quack_softmax_backward(g, fwd) + put_grad(a, a_grad) + + return fwd + + for ltorch_softmax in (ltorch._softmax, ltorch.softmax): + cutlass_dsl_ex.register_implementation( + ltorch_softmax, + checker=quack_softmax_checker, + execution_transform=quack_softmax_transform, + grad_transform=quack_softmax_grad, + ) From 2f42fa2c0c8fc866aefad322ee1777d94214abe7 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Fri, 18 Jul 2025 18:06:12 +0900 Subject: [PATCH 02/22] [no ci] add crossentropy Signed-off-by: Masaki Kozuki --- thunder/executors/cutlass_dsl_ex.py | 130 +++++++++++++++++++++++++++- 1 file changed, 129 insertions(+), 1 deletion(-) diff --git a/thunder/executors/cutlass_dsl_ex.py b/thunder/executors/cutlass_dsl_ex.py index c817eb886b..30c5f489be 100644 --- a/thunder/executors/cutlass_dsl_ex.py +++ b/thunder/executors/cutlass_dsl_ex.py @@ -13,6 +13,7 @@ import thunder.torch as ltorch if TYPE_CHECKING: + from typing import Any from thunder.core.dtypes import dtype as thunder_dtype @@ -76,8 +77,13 @@ def cutlass_dsl_available() -> bool: register_executor(cutlass_dsl_ex) +def is_device_quack_compat() -> bool: + return torch.cuda.get_device_capability() in ((9, 0), (10, 0)) + + # Register [`quack`](https://github.com/Dao-AILab/quack) ops if find_spec("quack") is not None: + # softmax from quack.softmax import _softmax_fwd, _softmax_backward def quack_softmax_impl(a: torch.Tensor) -> torch.Tensor: @@ -113,7 +119,7 @@ def quack_softmax_checker( ) -> bool: last_dims = {-1, a.ndim - 1} allowed_dtypes = {None, a.dtype} - return dim in last_dims and dtype in allowed_dtypes and torch.cuda.get_device_capability() in ((9, 0), (10, 0)) + return dim in last_dims and dtype in allowed_dtypes and is_device_quack_compat() def quack_softmax_transform( a: TensorProxy, @@ -145,3 +151,125 @@ def quack_softmax_grad( execution_transform=quack_softmax_transform, grad_transform=quack_softmax_grad, ) + + # crossentropy + from quack.cross_entropy import _cross_entropy, _cross_entropy_backward + + def quack_cross_entropy_forward_impl( + x: torch.Tensor, + target: torch.Tensor, + ) -> torch.Tensor: + return _cross_entropy(x, target, return_lse=False) + + def quack_cross_entropy_forward_meta(x: TensorProxy, target: TensorProxy) -> TensorProxy: + return TensorProxy(like=x, shape=(x.shape[0],)) + + quack_cross_entropy_forward = cutlass_dsl_ex.register_operator( + "cutlass_quack_cross_entropy_forward", + meta=quack_cross_entropy_forward_meta, + fn=quack_cross_entropy_forward_impl, + ) + + def quack_cross_entropy_backward_impl( + x: torch.Tensor, + target: torch.Tensor, + grad: torch.Tensor, + lse: torch.Tensor, + ) -> torch.Tensor: + return _cross_entropy_backward(x, target, grad, lse, False) + + def quack_cross_entropy_backward_meta( + x: TensorProxy, + target: TensorProxy, + grad: TensorProxy, + lse: TensorProxy, + ) -> TensorProxy: + return TensorProxy(like=grad) + + quack_cross_entropy_backward = cutlass_dsl_ex.register_operator( + "cutlass_quack_cross_entropy_backward", + meta=quack_softmax_backward_meta, + fn=quack_cross_entropy_backward_impl, + ) + + def quack_cross_entropy_checker( + a: TensorProxy, + /, + target: TensorProxy, + weight: TensorProxy | None = None, + size_average: bool | None = None, + ignore_index: int = -100, + reduce: bool | None = None, + reduction: str = "mean", + label_smoothing: float = 0.0, + ) -> bool: + if not is_device_quack_compat(): + return False + if weight is not None: + return False + + # Assert deprecated flags are not used + for boolean_flag in (size_average, reduce): + if boolean_flag is not None: + return False + + if reduction != "none": + return False + + if label_smoothing != 0.0: + return False + + return True + + def quack_cross_entropy_transform( + a: TensorProxy, + /, + target: TensorProxy, + weight: TensorProxy | None = None, + size_average: bool | None = None, + ignore_index: int = -100, + reduce: bool | None = None, + reduction: str = "mean", + label_smoothing: float = 0.0, + ) -> TensorProxy: + return quack_cross_entropy_forward(a, target) + + def quack_cross_entropy_aug_forward_impl( + x: torch.Tensor, + target: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + return _cross_entropy(x, target, return_lse=True) + + def quack_cross_entropy_aug_forward_meta(a: TensorProxy, target: TensorProxy) -> tuple[TensorProxy, TensorProxy]: + return (TensorProxy(like=a, shape=(a.shape[0],)), TensorProxy(like=a, shape=(a.shape[0],))) + + quack_cross_entropy_aug_forward = cutlass_dsl_ex.register_operator( + "cutlass_quack_cross_entropy_aug_forward", + meta=quack_cross_entropy_aug_forward_meta, + fn=quack_cross_entropy_aug_forward_impl, + ) + + def quack_cross_entropy_grad( + a: TensorProxy, + /, + target: TensorProxy, + weight: TensorProxy | None = None, + size_average: bool | None = None, + ignore_index: int = -100, + reduce: bool | None = None, + reduction: str = "mean", + label_smoothing: float = 0.0, + ) -> TensorProxy: + fwd, lse = quack_cross_entropy_aug_forward(a, target) + g = get_grad(fwd) + a_grad = quack_cross_entropy_backward(a, target, g, lse) + put_grad(a, a_grad) + + return fwd + + cutlass_dsl_ex.register_implementation( + ltorch.cross_entropy, + checker=quack_cross_entropy_checker, + execution_transform=quack_cross_entropy_transform, + grad_transform=quack_cross_entropy_grad, + ) From b95c9e4506b074c8eb22b281be1122ed59ab98b0 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Fri, 18 Jul 2025 18:30:22 +0900 Subject: [PATCH 03/22] [no ci] add layer norm forward Signed-off-by: Masaki Kozuki --- thunder/executors/cutlass_dsl_ex.py | 56 ++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/thunder/executors/cutlass_dsl_ex.py b/thunder/executors/cutlass_dsl_ex.py index 30c5f489be..54f889a9f2 100644 --- a/thunder/executors/cutlass_dsl_ex.py +++ b/thunder/executors/cutlass_dsl_ex.py @@ -13,7 +13,8 @@ import thunder.torch as ltorch if TYPE_CHECKING: - from typing import Any + from collections.abc import Sequence + from numbers import Number from thunder.core.dtypes import dtype as thunder_dtype @@ -273,3 +274,56 @@ def quack_cross_entropy_grad( execution_transform=quack_cross_entropy_transform, grad_transform=quack_cross_entropy_grad, ) + + # layernorm (only forward as of https://github.com/Dao-AILab/quack/commit/3ce89a24) + from quack.layernorm import layernorm + + def quack_layer_norm_forward_impl( + x: torch.Tensor, + weight: torch.Tensor, + eps: float, + return_rstd: bool, + return_mean: bool, + ) -> torch.Tensor: + return layernorm(x, weight, eps, return_rstd=return_rstd, return_mean=return_mean) + + def quack_layer_norm_forward_meta( + x: torch.Tensor, + weight: torch.Tensor, + eps: float, + return_rstd: bool, + return_mean: bool, + ) -> TensorProxy: + return TensorProxy(like=x) + + quack_layer_norm_forward = cutlass_dsl_ex.register_operator( + "cutlass_quack_layer_norm_forward", + meta=quack_layer_norm_forward_meta, + fn=quack_layer_norm_forward_impl, + ) + + def quack_layer_norm_checker( + a: TensorProxy, + /, + normalized_shape: Sequence[int], + weight: TensorProxy | None = None, + bias: TensorProxy | None = None, + eps: Number = 1e-5, + ) -> bool: + return is_device_quack_compat() + + def quack_layer_norm_transform( + a: TensorProxy, + /, + normalized_shape: Sequence[int], + weight: TensorProxy | None = None, + bias: TensorProxy | None = None, + eps: Number = 1e-5, + ) -> TensorProxy: + return quack_layer_norm_forward(a, weight, eps) + + cutlass_dsl_ex.register_operator( + ltorch.layer_norm, + checker=quack_layer_norm_checker, + execution_transform=quack_layer_norm_transform, + ) From a9eaf4ddb7886971c1cbdddb033ca2e3f9b42125 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Fri, 18 Jul 2025 18:48:26 +0900 Subject: [PATCH 04/22] [no ci] add rmsnorm Signed-off-by: Masaki Kozuki --- thunder/executors/cutlass_dsl_ex.py | 114 +++++++++++++++++++++++++++- 1 file changed, 111 insertions(+), 3 deletions(-) diff --git a/thunder/executors/cutlass_dsl_ex.py b/thunder/executors/cutlass_dsl_ex.py index 54f889a9f2..206e02b23f 100644 --- a/thunder/executors/cutlass_dsl_ex.py +++ b/thunder/executors/cutlass_dsl_ex.py @@ -288,8 +288,8 @@ def quack_layer_norm_forward_impl( return layernorm(x, weight, eps, return_rstd=return_rstd, return_mean=return_mean) def quack_layer_norm_forward_meta( - x: torch.Tensor, - weight: torch.Tensor, + x: TensorProxy, + weight: TensorProxy, eps: float, return_rstd: bool, return_mean: bool, @@ -320,10 +320,118 @@ def quack_layer_norm_transform( bias: TensorProxy | None = None, eps: Number = 1e-5, ) -> TensorProxy: - return quack_layer_norm_forward(a, weight, eps) + return quack_layer_norm_forward(a, weight, eps, return_rstd=False, return_mean=False) cutlass_dsl_ex.register_operator( ltorch.layer_norm, checker=quack_layer_norm_checker, execution_transform=quack_layer_norm_transform, ) + + # rmsnorm + from quack.rmsnorm import _rmsnorm_fwd, _rmsnorm_backward + + def quack_rms_norm_forward_impl( + x: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, + ) -> torch.Tensor: + return _rmsnorm_fwd(x, weight, eps, return_rstd=False) + + def quack_rms_norm_forward_meta( + x: TensorProxy, + weight: TensorProxy, + eps: float = 1e-6, + ) -> TensorProxy: + return TensorProxy(like=x) + + quack_rms_norm_forward = cutlass_dsl_ex.register_operator( + "cutlass_quack_rms_norm_forward", + meta=quack_rms_norm_forward_meta, + fn=quack_rms_norm_forward_impl, + ) + + def quack_rms_norm_backward_impl( + grad: torch.Tensor, + x: torch.Tensor, + weight: torch.Tensor, + rstd: torch.Tensor, + ) -> torch.Tensor: + return _rmsnorm_backward(x, weight, grad, rstd) + + def quack_rms_norm_backward_meta( + grad: TensorProxy, + x: TensorProxy, + weight: TensorProxy, + rstd: TensorProxy, + ) -> TensorProxy: + return TensorProxy(like=grad) + + quack_rms_norm_backward = cutlass_dsl_ex.register_operator( + "cutlass_quack_rms_norm_backward", + meta=quack_rms_norm_forward_meta, + fn=quack_rms_norm_backward_impl, + ) + + def quack_rms_norm_checker( + a: TensorProxy, + /, + normalized_shape: Sequence[int], + weight: TensorProxy | None = None, + eps: float | None = None, + ) -> bool: + return weight is not None and is_device_quack_compat() + + def quack_rms_norm_transform( + a: TensorProxy, + /, + normalized_shape: Sequence[int], + weight: TensorProxy | None = None, + eps: float | None = None, + ) -> TensorProxy: + if eps is None: + eps = 1e-6 + return quack_rms_norm_forward(a, weight, eps) + + def quack_rms_norm_aug_forward_impl( + x: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, + ) -> tuple[torch.Tensor, torch.Tensor]: + return _rmsnorm_fwd(x, weight, eps, return_rstd=True) + + def quack_rms_norm_aug_forward_meta( + x: TensorProxy, + weight: TensorProxy, + eps: float = 1e-6, + ) -> tuple[TensorProxy, TensorProxy]: + return (TensorProxy(like=x), TensorProxy(like=x, shape=(x.shape[0]))) + + quack_rms_norm_aug_forward = cutlass_dsl_ex.register_operator( + "cutlass_quack_rms_norm_aug_forward", + meta=quack_rms_norm_aug_forward_meta, + fn=quack_rms_norm_aug_forward_impl, + ) + + def quack_rms_norm_grad( + a: TensorProxy, + /, + normalized_shape: Sequence[int], + weight: TensorProxy | None = None, + eps: float | None = None, + ) -> TensorProxy: + if eps is None: + eps = 1e-6 + fwd, rstd = quack_rms_norm_aug_forward(a, weight, eps) + + grad = get_grad(fwd) + a_grad = quack_rms_norm_backward(grad, a, weight, rstd) + put_grad(a, a_grad) + return fwd + + cutlass_dsl_ex.register_implementation( + ltorch.rms_norm, + checker=quack_rms_norm_checker, + execution_transform=quack_rms_norm_transform, + grad_transform=quack_rms_norm_grad, + ) From 0b51a1280476842e4da9157fbb6ace04bb68404e Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Fri, 18 Jul 2025 03:07:31 -0700 Subject: [PATCH 05/22] fix Signed-off-by: Masaki Kozuki --- thunder/executors/cutlass_dsl_ex.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/thunder/executors/cutlass_dsl_ex.py b/thunder/executors/cutlass_dsl_ex.py index 206e02b23f..46b6e875ff 100644 --- a/thunder/executors/cutlass_dsl_ex.py +++ b/thunder/executors/cutlass_dsl_ex.py @@ -322,7 +322,7 @@ def quack_layer_norm_transform( ) -> TensorProxy: return quack_layer_norm_forward(a, weight, eps, return_rstd=False, return_mean=False) - cutlass_dsl_ex.register_operator( + cutlass_dsl_ex.register_implementation( ltorch.layer_norm, checker=quack_layer_norm_checker, execution_transform=quack_layer_norm_transform, @@ -369,7 +369,7 @@ def quack_rms_norm_backward_meta( quack_rms_norm_backward = cutlass_dsl_ex.register_operator( "cutlass_quack_rms_norm_backward", - meta=quack_rms_norm_forward_meta, + meta=quack_rms_norm_backward_meta, fn=quack_rms_norm_backward_impl, ) @@ -382,17 +382,6 @@ def quack_rms_norm_checker( ) -> bool: return weight is not None and is_device_quack_compat() - def quack_rms_norm_transform( - a: TensorProxy, - /, - normalized_shape: Sequence[int], - weight: TensorProxy | None = None, - eps: float | None = None, - ) -> TensorProxy: - if eps is None: - eps = 1e-6 - return quack_rms_norm_forward(a, weight, eps) - def quack_rms_norm_aug_forward_impl( x: torch.Tensor, weight: torch.Tensor, @@ -405,7 +394,7 @@ def quack_rms_norm_aug_forward_meta( weight: TensorProxy, eps: float = 1e-6, ) -> tuple[TensorProxy, TensorProxy]: - return (TensorProxy(like=x), TensorProxy(like=x, shape=(x.shape[0]))) + return (TensorProxy(like=x), TensorProxy(like=x, shape=(x.shape[0],))) quack_rms_norm_aug_forward = cutlass_dsl_ex.register_operator( "cutlass_quack_rms_norm_aug_forward", @@ -413,6 +402,17 @@ def quack_rms_norm_aug_forward_meta( fn=quack_rms_norm_aug_forward_impl, ) + def quack_rms_norm_transform( + a: TensorProxy, + /, + normalized_shape: Sequence[int], + weight: TensorProxy | None = None, + eps: float | None = None, + ) -> TensorProxy: + if eps is None: + eps = 1e-6 + return quack_rms_norm_aug_forward(a, weight, eps)[0] + def quack_rms_norm_grad( a: TensorProxy, /, From d769096ee3b0f748615933df8583772919a5f394 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Sun, 20 Jul 2025 01:11:41 -0700 Subject: [PATCH 06/22] fix backward of crossentropy Signed-off-by: Masaki Kozuki --- thunder/executors/cutlass_dsl_ex.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/thunder/executors/cutlass_dsl_ex.py b/thunder/executors/cutlass_dsl_ex.py index 46b6e875ff..4e76c0d7e8 100644 --- a/thunder/executors/cutlass_dsl_ex.py +++ b/thunder/executors/cutlass_dsl_ex.py @@ -185,11 +185,11 @@ def quack_cross_entropy_backward_meta( grad: TensorProxy, lse: TensorProxy, ) -> TensorProxy: - return TensorProxy(like=grad) + return TensorProxy(like=x) quack_cross_entropy_backward = cutlass_dsl_ex.register_operator( "cutlass_quack_cross_entropy_backward", - meta=quack_softmax_backward_meta, + meta=quack_cross_entropy_backward_meta, fn=quack_cross_entropy_backward_impl, ) From 4a0ced4a62cc820b8d7d14ef74ebcc4be2e55d68 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Sun, 20 Jul 2025 01:36:16 -0700 Subject: [PATCH 07/22] fix checkers Signed-off-by: Masaki Kozuki --- thunder/executors/cutlass_dsl_ex.py | 37 ++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/thunder/executors/cutlass_dsl_ex.py b/thunder/executors/cutlass_dsl_ex.py index 4e76c0d7e8..21b7caaa3d 100644 --- a/thunder/executors/cutlass_dsl_ex.py +++ b/thunder/executors/cutlass_dsl_ex.py @@ -7,6 +7,7 @@ from looseversion import LooseVersion import torch +from thunder.core import dtypes from thunder.core.transforms import get_grad, put_grad from thunder.extend import register_executor, OperatorExecutor from thunder.core.proxies import TensorProxy @@ -111,6 +112,7 @@ def quack_softmax_backward_meta(g: TensorProxy, a: TensorProxy) -> TensorProxy: fn=quack_softmax_backward, ) + # Ref: https://github.com/Dao-AILab/quack/blob/3ce89a24/quack/softmax.py#L189-L198 def quack_softmax_checker( a: TensorProxy, /, @@ -120,7 +122,13 @@ def quack_softmax_checker( ) -> bool: last_dims = {-1, a.ndim - 1} allowed_dtypes = {None, a.dtype} - return dim in last_dims and dtype in allowed_dtypes and is_device_quack_compat() + return ( + a.ndim == 2 + and dim in last_dims + and dtype in allowed_dtypes + and a.dtype in {dtypes.float16, dtypes.bfloat16, dtypes.float32} + and is_device_quack_compat() + ) def quack_softmax_transform( a: TensorProxy, @@ -193,6 +201,7 @@ def quack_cross_entropy_backward_meta( fn=quack_cross_entropy_backward_impl, ) + # Ref: https://github.com/Dao-AILab/quack/blob/3ce89a24/quack/cross_entropy.py#L216-L239 def quack_cross_entropy_checker( a: TensorProxy, /, @@ -220,6 +229,14 @@ def quack_cross_entropy_checker( if label_smoothing != 0.0: return False + if ( + a.ndim != 2 + or a.dtype not in {dtypes.float16, dtypes.bfloat16, dtypes.float32} + and target.ndim == 1 + and target.dytpe in {dtypes.int32, dtypes.int64} + ): + return False + return True def quack_cross_entropy_transform( @@ -302,6 +319,7 @@ def quack_layer_norm_forward_meta( fn=quack_layer_norm_forward_impl, ) + # Ref: https://github.com/Dao-AILab/quack/blob/3ce89a24/quack/layernorm.py#L252-L278 def quack_layer_norm_checker( a: TensorProxy, /, @@ -310,6 +328,14 @@ def quack_layer_norm_checker( bias: TensorProxy | None = None, eps: Number = 1e-5, ) -> bool: + if ( + a.ndim != 2 + or a.dtype not in {dtypes.float16, dtypes.bfloat16, dtypes.float32} + or weight.ndim != 1 + or a.shape[-1] != weight.shape[0] + or weight.dtype not in {dtypes.float32} + ): + return False return is_device_quack_compat() def quack_layer_norm_transform( @@ -373,6 +399,7 @@ def quack_rms_norm_backward_meta( fn=quack_rms_norm_backward_impl, ) + # Ref: https://github.com/Dao-AILab/quack/blob/3ce89a24/quack/rmsnorm.py#L231-L261 def quack_rms_norm_checker( a: TensorProxy, /, @@ -380,6 +407,14 @@ def quack_rms_norm_checker( weight: TensorProxy | None = None, eps: float | None = None, ) -> bool: + if ( + a.ndim != 2 + or weight.ndim != 1 + or a.shape[-1] != weight.shape[0] + or a.dtype not in {dtypes.float16, dtypes.bfloat16, dtypes.float32} + or weight.dtype not in {dtypes.float16, dtypes.bfloat16, dtypes.float32} + ): + return False return weight is not None and is_device_quack_compat() def quack_rms_norm_aug_forward_impl( From 1a2a868c13fd9f7a8113dacb01b50c364bcac58d Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Mon, 21 Jul 2025 14:20:16 +0900 Subject: [PATCH 08/22] [no ci] add test Signed-off-by: Masaki Kozuki --- thunder/tests/test_cutlass_dsl_ex.py | 110 +++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 thunder/tests/test_cutlass_dsl_ex.py diff --git a/thunder/tests/test_cutlass_dsl_ex.py b/thunder/tests/test_cutlass_dsl_ex.py new file mode 100644 index 0000000000..5b97b8ab86 --- /dev/null +++ b/thunder/tests/test_cutlass_dsl_ex.py @@ -0,0 +1,110 @@ +from importlib.util import find_spec +from typing import TYPE_CHECKING + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F + +import thunder +from thunder.executors.cutlass_dsl_ex import cutlass_dsl_ex, is_device_quack_compat +from thunder.tests.framework import requiresCUDA + +if TYPE_CHECKING: + from typing import Any + from collections.abc import Callable + + +_quack_available = find_spec("quack") is not None +quack_available = pytest.mark.skipif( + not is_device_quack_compat() or not _quack_available, + reason="quack requires SM9.0/10.0", +) + + +@pytest.fixture(autouse=True, scope="module") +def set_cuda_as_default_device(): + original_default_device: torch.device | None = None + if torch.cuda.is_available(): + original_default_device = torch.get_default_device() + torch.set_default_device("cuda") + yield + + # Teardown + if original_default_device is not None: + torch.set_default_device(original_default_device) + + +def jit_with_cutlass_dsl_ex(fn: Callable[[Any], Any]) -> Callable[[Any], Any]: + return thunder.jit(fn, executors=[cutlass_dsl_ex]) + + +@requiresCUDA +@quack_available +@pytest.mark.parametrize("dtype", (torch.float16, torch.bfloat16, torch.float32)) +def test_quack_cross_entropy(dtype: torch.dtype): + x = torch.randn((128, 1024), dtype=dtype, requires_grad=True) + ref_x = x.clone().detach() + targets = torch.randint(0, 128, (128,), dtype=torch.int64) + + jitted = jit_with_cutlass_dsl_ex(F.cross_entropy) + + expected = F.cross_entropy(ref_x, targets, reduction="none") + actual = jitted(x, targets, reduction="none") + torch.testing.assert_close(expected, actual) + + expected_grad = torch.autograd.grad((expected,), (ref_x, targets)) + actual_grad = torch.autograd.grad((actual,), (x, targets)) + torch.testing.assert_close(expected_grad, actual_grad) + + +@requiresCUDA +@quack_available +@pytest.mark.parametrize("dtype", (torch.float16, torch.bfloat16, torch.float32)) +def test_quack_softmax(dtype: torch.dtype): + x = torch.randn((128, 1024), dtype=dtype, requires_grad=True) + ref_x = x.clone().detach() + + jitted = jit_with_cutlass_dsl_ex(F.softmax) + + expected = F.softmax(ref_x, dim=-1, reduction="none") + actual = jitted(x, dim=-1, reduction="none") + torch.testing.assert_close(expected, actual) + + expected_grad = torch.autograd.grad((expected,), (ref_x,)) + actual_grad = torch.autograd.grad((actual,), (x,)) + torch.testing.assert_close(expected_grad, actual_grad) + + +@requiresCUDA +@quack_available +@pytest.mark.parametrize("dtype", (torch.float16, torch.bfloat16, torch.float32)) +def test_quack_layernorm(dtype: torch.dtype): + x = torch.randn((128, 1024), dtype=dtype, requires_grad=True) + ref_x = x.clone().detach() + + module = nn.LayerNorm(1024).cuda() + jitted = jit_with_cutlass_dsl_ex(module) + + expected = module(ref_x) + actual = jitted(x) + torch.testing.assert_close(expected, actual) + + +@requiresCUDA +@quack_available +@pytest.mark.parametrize("dtype", (torch.float16, torch.bfloat16, torch.float32)) +def test_quack_rmsrnorm(dtype: torch.dtype): + x = torch.randn((128, 1024), dtype=dtype, requires_grad=True) + ref_x = x.clone().detach() + + module = nn.RMSNorm(1024).cuda() + jitted = jit_with_cutlass_dsl_ex(module) + + expected = module(ref_x) + actual = jitted(x) + torch.testing.assert_close(expected, actual) + + expected_grad = torch.autograd.grad((expected,), (ref_x,)) + actual_grad = torch.autograd.grad((actual,), (x,)) + torch.testing.assert_close(expected_grad, actual_grad) From d6efb9a69277eb5c47c839fdf1924940d43ae12a Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Sun, 20 Jul 2025 22:44:58 -0700 Subject: [PATCH 09/22] DRY: dtypes & their ids Signed-off-by: Masaki Kozuki --- thunder/tests/test_cutlass_dsl_ex.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/thunder/tests/test_cutlass_dsl_ex.py b/thunder/tests/test_cutlass_dsl_ex.py index 5b97b8ab86..c2a68fc3a5 100644 --- a/thunder/tests/test_cutlass_dsl_ex.py +++ b/thunder/tests/test_cutlass_dsl_ex.py @@ -1,3 +1,4 @@ +from __future__ import annotations from importlib.util import find_spec from typing import TYPE_CHECKING @@ -20,6 +21,8 @@ not is_device_quack_compat() or not _quack_available, reason="quack requires SM9.0/10.0", ) +_DTYPES = (torch.float16, torch.bfloat16, torch.float32) +_DTYPE_IDS = tuple(str(a) for a in _DTYPES) @pytest.fixture(autouse=True, scope="module") @@ -41,7 +44,7 @@ def jit_with_cutlass_dsl_ex(fn: Callable[[Any], Any]) -> Callable[[Any], Any]: @requiresCUDA @quack_available -@pytest.mark.parametrize("dtype", (torch.float16, torch.bfloat16, torch.float32)) +@pytest.mark.parametrize("dtype", _DTYPES, ids=_DTYPE_IDS) def test_quack_cross_entropy(dtype: torch.dtype): x = torch.randn((128, 1024), dtype=dtype, requires_grad=True) ref_x = x.clone().detach() @@ -60,7 +63,7 @@ def test_quack_cross_entropy(dtype: torch.dtype): @requiresCUDA @quack_available -@pytest.mark.parametrize("dtype", (torch.float16, torch.bfloat16, torch.float32)) +@pytest.mark.parametrize("dtype", _DTYPES, ids=_DTYPE_IDS) def test_quack_softmax(dtype: torch.dtype): x = torch.randn((128, 1024), dtype=dtype, requires_grad=True) ref_x = x.clone().detach() @@ -78,7 +81,7 @@ def test_quack_softmax(dtype: torch.dtype): @requiresCUDA @quack_available -@pytest.mark.parametrize("dtype", (torch.float16, torch.bfloat16, torch.float32)) +@pytest.mark.parametrize("dtype", _DTYPES, ids=_DTYPE_IDS) def test_quack_layernorm(dtype: torch.dtype): x = torch.randn((128, 1024), dtype=dtype, requires_grad=True) ref_x = x.clone().detach() @@ -93,7 +96,7 @@ def test_quack_layernorm(dtype: torch.dtype): @requiresCUDA @quack_available -@pytest.mark.parametrize("dtype", (torch.float16, torch.bfloat16, torch.float32)) +@pytest.mark.parametrize("dtype", _DTYPES, ids=_DTYPE_IDS) def test_quack_rmsrnorm(dtype: torch.dtype): x = torch.randn((128, 1024), dtype=dtype, requires_grad=True) ref_x = x.clone().detach() From c7bbf3461ac9d267a1da8453b4a149420cae509e Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Sun, 20 Jul 2025 22:45:48 -0700 Subject: [PATCH 10/22] comment out backward for now Signed-off-by: Masaki Kozuki --- thunder/tests/test_cutlass_dsl_ex.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/thunder/tests/test_cutlass_dsl_ex.py b/thunder/tests/test_cutlass_dsl_ex.py index c2a68fc3a5..58806131da 100644 --- a/thunder/tests/test_cutlass_dsl_ex.py +++ b/thunder/tests/test_cutlass_dsl_ex.py @@ -56,9 +56,9 @@ def test_quack_cross_entropy(dtype: torch.dtype): actual = jitted(x, targets, reduction="none") torch.testing.assert_close(expected, actual) - expected_grad = torch.autograd.grad((expected,), (ref_x, targets)) - actual_grad = torch.autograd.grad((actual,), (x, targets)) - torch.testing.assert_close(expected_grad, actual_grad) + # expected_grad = torch.autograd.grad((expected,), (ref_x, targets), ) + # actual_grad = torch.autograd.grad((actual,), (x, targets)) + # torch.testing.assert_close(expected_grad, actual_grad) @requiresCUDA @@ -74,9 +74,9 @@ def test_quack_softmax(dtype: torch.dtype): actual = jitted(x, dim=-1, reduction="none") torch.testing.assert_close(expected, actual) - expected_grad = torch.autograd.grad((expected,), (ref_x,)) - actual_grad = torch.autograd.grad((actual,), (x,)) - torch.testing.assert_close(expected_grad, actual_grad) + # expected_grad = torch.autograd.grad((expected,), (ref_x,)) + # actual_grad = torch.autograd.grad((actual,), (x,)) + # torch.testing.assert_close(expected_grad, actual_grad) @requiresCUDA @@ -108,6 +108,6 @@ def test_quack_rmsrnorm(dtype: torch.dtype): actual = jitted(x) torch.testing.assert_close(expected, actual) - expected_grad = torch.autograd.grad((expected,), (ref_x,)) - actual_grad = torch.autograd.grad((actual,), (x,)) - torch.testing.assert_close(expected_grad, actual_grad) + # expected_grad = torch.autograd.grad((expected,), (ref_x,)) + # actual_grad = torch.autograd.grad((actual,), (x,)) + # torch.testing.assert_close(expected_grad, actual_grad) From 4daff6e8c9a970229b561f04a2244ff19a50d1fd Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Sun, 20 Jul 2025 22:51:07 -0700 Subject: [PATCH 11/22] upcast inputs to fp32 for reference it seems that quack's cross-entropy function upcasts inputs to fp32, thus updating test and meta function Signed-off-by: Masaki Kozuki --- thunder/executors/cutlass_dsl_ex.py | 2 +- thunder/tests/test_cutlass_dsl_ex.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/thunder/executors/cutlass_dsl_ex.py b/thunder/executors/cutlass_dsl_ex.py index 21b7caaa3d..f61917cd9d 100644 --- a/thunder/executors/cutlass_dsl_ex.py +++ b/thunder/executors/cutlass_dsl_ex.py @@ -171,7 +171,7 @@ def quack_cross_entropy_forward_impl( return _cross_entropy(x, target, return_lse=False) def quack_cross_entropy_forward_meta(x: TensorProxy, target: TensorProxy) -> TensorProxy: - return TensorProxy(like=x, shape=(x.shape[0],)) + return TensorProxy(like=x, shape=(x.shape[0],), dtype=dtypes.float32) quack_cross_entropy_forward = cutlass_dsl_ex.register_operator( "cutlass_quack_cross_entropy_forward", diff --git a/thunder/tests/test_cutlass_dsl_ex.py b/thunder/tests/test_cutlass_dsl_ex.py index 58806131da..ae139300f5 100644 --- a/thunder/tests/test_cutlass_dsl_ex.py +++ b/thunder/tests/test_cutlass_dsl_ex.py @@ -48,6 +48,8 @@ def jit_with_cutlass_dsl_ex(fn: Callable[[Any], Any]) -> Callable[[Any], Any]: def test_quack_cross_entropy(dtype: torch.dtype): x = torch.randn((128, 1024), dtype=dtype, requires_grad=True) ref_x = x.clone().detach() + if dtype != torch.float32: + ref_x = ref_x.to(torch.float32) targets = torch.randint(0, 128, (128,), dtype=torch.int64) jitted = jit_with_cutlass_dsl_ex(F.cross_entropy) From 37114975ef8c7190e84874b8f7010f1c3cf244ef Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Sun, 20 Jul 2025 23:12:39 -0700 Subject: [PATCH 12/22] fix how softmax is called Signed-off-by: Masaki Kozuki --- thunder/tests/test_cutlass_dsl_ex.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/thunder/tests/test_cutlass_dsl_ex.py b/thunder/tests/test_cutlass_dsl_ex.py index ae139300f5..2e3094afc7 100644 --- a/thunder/tests/test_cutlass_dsl_ex.py +++ b/thunder/tests/test_cutlass_dsl_ex.py @@ -72,8 +72,8 @@ def test_quack_softmax(dtype: torch.dtype): jitted = jit_with_cutlass_dsl_ex(F.softmax) - expected = F.softmax(ref_x, dim=-1, reduction="none") - actual = jitted(x, dim=-1, reduction="none") + expected = F.softmax(ref_x, dim=-1) + actual = jitted(x, dim=-1) torch.testing.assert_close(expected, actual) # expected_grad = torch.autograd.grad((expected,), (ref_x,)) From 85a8e6592ce3fb9cdbe95274899fa0f226978b19 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Sun, 20 Jul 2025 23:15:30 -0700 Subject: [PATCH 13/22] upcast and downcast for reference layernorm Signed-off-by: Masaki Kozuki --- thunder/tests/test_cutlass_dsl_ex.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/thunder/tests/test_cutlass_dsl_ex.py b/thunder/tests/test_cutlass_dsl_ex.py index 2e3094afc7..48db48d63f 100644 --- a/thunder/tests/test_cutlass_dsl_ex.py +++ b/thunder/tests/test_cutlass_dsl_ex.py @@ -86,12 +86,12 @@ def test_quack_softmax(dtype: torch.dtype): @pytest.mark.parametrize("dtype", _DTYPES, ids=_DTYPE_IDS) def test_quack_layernorm(dtype: torch.dtype): x = torch.randn((128, 1024), dtype=dtype, requires_grad=True) - ref_x = x.clone().detach() + ref_x = x.clone().detach().to(torch.float32) module = nn.LayerNorm(1024).cuda() jitted = jit_with_cutlass_dsl_ex(module) - expected = module(ref_x) + expected = module(ref_x).to(dtype) actual = jitted(x) torch.testing.assert_close(expected, actual) From 1b633d5ea415e5644a57747bb986286a7eefacd8 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Sun, 20 Jul 2025 23:17:28 -0700 Subject: [PATCH 14/22] fix typo of rmsnorm Signed-off-by: Masaki Kozuki --- thunder/tests/test_cutlass_dsl_ex.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/tests/test_cutlass_dsl_ex.py b/thunder/tests/test_cutlass_dsl_ex.py index 48db48d63f..a532f1a8ab 100644 --- a/thunder/tests/test_cutlass_dsl_ex.py +++ b/thunder/tests/test_cutlass_dsl_ex.py @@ -99,7 +99,7 @@ def test_quack_layernorm(dtype: torch.dtype): @requiresCUDA @quack_available @pytest.mark.parametrize("dtype", _DTYPES, ids=_DTYPE_IDS) -def test_quack_rmsrnorm(dtype: torch.dtype): +def test_quack_rmsnorm(dtype: torch.dtype): x = torch.randn((128, 1024), dtype=dtype, requires_grad=True) ref_x = x.clone().detach() From 27ff1590d372a183cab31b05e6c0b75a9a9208b2 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Sun, 20 Jul 2025 23:28:20 -0700 Subject: [PATCH 15/22] fix meta Signed-off-by: Masaki Kozuki --- thunder/executors/cutlass_dsl_ex.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/thunder/executors/cutlass_dsl_ex.py b/thunder/executors/cutlass_dsl_ex.py index f61917cd9d..b54fe050d4 100644 --- a/thunder/executors/cutlass_dsl_ex.py +++ b/thunder/executors/cutlass_dsl_ex.py @@ -259,7 +259,10 @@ def quack_cross_entropy_aug_forward_impl( return _cross_entropy(x, target, return_lse=True) def quack_cross_entropy_aug_forward_meta(a: TensorProxy, target: TensorProxy) -> tuple[TensorProxy, TensorProxy]: - return (TensorProxy(like=a, shape=(a.shape[0],)), TensorProxy(like=a, shape=(a.shape[0],))) + return ( + TensorProxy(like=a, shape=(a.shape[0],), dtype=dtypes.float32), + TensorProxy(like=a, shape=(a.shape[0],), dtype=dtypes.float32), + ) quack_cross_entropy_aug_forward = cutlass_dsl_ex.register_operator( "cutlass_quack_cross_entropy_aug_forward", @@ -429,7 +432,7 @@ def quack_rms_norm_aug_forward_meta( weight: TensorProxy, eps: float = 1e-6, ) -> tuple[TensorProxy, TensorProxy]: - return (TensorProxy(like=x), TensorProxy(like=x, shape=(x.shape[0],))) + return (TensorProxy(like=x), TensorProxy(like=x, shape=(x.shape[0],), dtype=dtypes.float32)) quack_rms_norm_aug_forward = cutlass_dsl_ex.register_operator( "cutlass_quack_rms_norm_aug_forward", From 0958dbaa8f09dfb414272134ea22dfdf10d0e2e7 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Mon, 21 Jul 2025 00:45:09 -0700 Subject: [PATCH 16/22] add cutlass_dsl_ex to all_executors Signed-off-by: Masaki Kozuki --- thunder/extend/__init__.py | 1 + thunder/tests/test_extend.py | 1 + 2 files changed, 2 insertions(+) diff --git a/thunder/extend/__init__.py b/thunder/extend/__init__.py index 020345e51a..c2ff221d85 100644 --- a/thunder/extend/__init__.py +++ b/thunder/extend/__init__.py @@ -533,6 +533,7 @@ def get_all_executors() -> tuple[Executor, ...]: apexex, cudnn_layernormex, cudnnex, + cutlass_dsl_ex, nvfuserex, pythonex, sdpaex, diff --git a/thunder/tests/test_extend.py b/thunder/tests/test_extend.py index 13b8d50ef1..825de24f7b 100644 --- a/thunder/tests/test_extend.py +++ b/thunder/tests/test_extend.py @@ -127,6 +127,7 @@ def test_get_all_executors_includes_all_native_executors(): expected = { "apex", "custom_op", + "cutlass_dsl", "fa3", "torch", "sdpa", From 668087858734b57f36df4b1ee1923d5ff01ad153 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Wed, 23 Jul 2025 05:43:21 -0700 Subject: [PATCH 17/22] Only forward, no backward support for now Signed-off-by: Masaki Kozuki --- thunder/executors/cutlass_dsl_ex.py | 130 +++++++++++++++++++++------ thunder/tests/test_cutlass_dsl_ex.py | 35 +++----- 2 files changed, 116 insertions(+), 49 deletions(-) diff --git a/thunder/executors/cutlass_dsl_ex.py b/thunder/executors/cutlass_dsl_ex.py index b54fe050d4..c2d6fbed27 100644 --- a/thunder/executors/cutlass_dsl_ex.py +++ b/thunder/executors/cutlass_dsl_ex.py @@ -83,13 +83,24 @@ def is_device_quack_compat() -> bool: return torch.cuda.get_device_capability() in ((9, 0), (10, 0)) +# NOTE: This constraint comes from https://github.com/Dao-AILab/quack/blob/59631e98/quack/reduction_base.py#L35-L38 +def is_last_dim_divisible(dtype: dtypes.dtype, last_dim_size: int) -> bool: + return last_dim_size % (128 // 8 // dtype.bytes) == 0 + + # Register [`quack`](https://github.com/Dao-AILab/quack) ops if find_spec("quack") is not None: # softmax from quack.softmax import _softmax_fwd, _softmax_backward def quack_softmax_impl(a: torch.Tensor) -> torch.Tensor: - return _softmax_fwd(a) + original_shape = a.shape + if requires_reshpae := a.ndim > 2: + a = a.view(-1, original_shape[-1]) + ret = _softmax_fwd(a) + if requires_reshpae: + ret = ret.view(original_shape) + return ret def quack_softmax_meta(a: TensorProxy) -> TensorProxy: return TensorProxy(like=a) @@ -101,7 +112,14 @@ def quack_softmax_meta(a: TensorProxy) -> TensorProxy: ) def quack_softmax_backward(g: torch.Tensor, a: torch.Tensor) -> torch.Tensor: - return _softmax_backward(g, a) + original_shape = g.shape + if requires_reshape := g.ndim > 2: + g = g.view(-1, original_shape[-1]) + a = a.view(-1, original_shape[-1]) + ret = _softmax_backward(g, a) + if requires_reshape: + ret = ret.view(original_shape) + return ret def quack_softmax_backward_meta(g: TensorProxy, a: TensorProxy) -> TensorProxy: return TensorProxy(like=g) @@ -123,11 +141,11 @@ def quack_softmax_checker( last_dims = {-1, a.ndim - 1} allowed_dtypes = {None, a.dtype} return ( - a.ndim == 2 - and dim in last_dims + dim in last_dims and dtype in allowed_dtypes and a.dtype in {dtypes.float16, dtypes.bfloat16, dtypes.float32} and is_device_quack_compat() + and is_last_dim_divisible(a.dtype, a.shape[-1]) ) def quack_softmax_transform( @@ -139,26 +157,42 @@ def quack_softmax_transform( ) -> TensorProxy: return quack_softmax(a) - def quack_softmax_grad( - a: TensorProxy, - /, - dim: int, - *, - dtype: thunder_dtype | None = None, - ) -> TensorProxy: - fwd = quack_softmax(a) - g = get_grad(fwd) - a_grad = quack_softmax_backward(g, fwd) - put_grad(a, a_grad) - - return fwd + # NOTE: Softmax backward doesn't look functioning as follows: + # def _engine_run_backward( + # t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], + # *args: Any, + # **kwargs: Any, + # ) -> tuple[torch.Tensor, ...]: + # attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG + # if attach_logging_hooks: + # unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) + # try: + # > return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass + # t_outputs, *args, **kwargs + # ) # Calls into the C++ engine to run the backward pass + # E RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn + # + # /pytorch/torch/autograd/graph.py:829: RuntimeError + # def quack_softmax_grad( + # a: TensorProxy, + # /, + # dim: int, + # *, + # dtype: thunder_dtype | None = None, + # ) -> TensorProxy: + # fwd = quack_softmax(a) + # g = get_grad(fwd) + # a_grad = quack_softmax_backward(g, fwd) + # put_grad(a, a_grad) + + # return fwd for ltorch_softmax in (ltorch._softmax, ltorch.softmax): cutlass_dsl_ex.register_implementation( ltorch_softmax, checker=quack_softmax_checker, execution_transform=quack_softmax_transform, - grad_transform=quack_softmax_grad, + # grad_transform=quack_softmax_grad, ) # crossentropy @@ -305,7 +339,13 @@ def quack_layer_norm_forward_impl( return_rstd: bool, return_mean: bool, ) -> torch.Tensor: - return layernorm(x, weight, eps, return_rstd=return_rstd, return_mean=return_mean) + original_shape = x.shape + if requires_reshape := x.ndim > 2: + x = x.view(-1, original_shape[-1]) + ret = layernorm(x, weight, eps, return_rstd=return_rstd, return_mean=return_mean) + if requires_reshape: + ret = ret.view(original_shape) + return ret def quack_layer_norm_forward_meta( x: TensorProxy, @@ -332,8 +372,7 @@ def quack_layer_norm_checker( eps: Number = 1e-5, ) -> bool: if ( - a.ndim != 2 - or a.dtype not in {dtypes.float16, dtypes.bfloat16, dtypes.float32} + a.dtype not in {dtypes.float16, dtypes.bfloat16, dtypes.float32} or weight.ndim != 1 or a.shape[-1] != weight.shape[0] or weight.dtype not in {dtypes.float32} @@ -365,7 +404,13 @@ def quack_rms_norm_forward_impl( weight: torch.Tensor, eps: float = 1e-6, ) -> torch.Tensor: - return _rmsnorm_fwd(x, weight, eps, return_rstd=False) + original_shape = x.shape + if requires_reshape := x.ndim > 2: + x = x.view(-1, original_shape[-1]) + ret = _rmsnorm_fwd(x, weight, eps, return_rstd=False) + if requires_reshape: + ret = ret.view(original_shape) + return ret def quack_rms_norm_forward_meta( x: TensorProxy, @@ -386,7 +431,14 @@ def quack_rms_norm_backward_impl( weight: torch.Tensor, rstd: torch.Tensor, ) -> torch.Tensor: - return _rmsnorm_backward(x, weight, grad, rstd) + original_shape = grad.shape + if requires_reshape := grad.ndim > 2: + grad = grad.view(-1, original_shape[-1]) + x = x.view(-1, original_shape[-1]) + ret = _rmsnorm_backward(x, weight, grad, rstd) + if requires_reshape: + ret = ret.view(original_shape) + return ret def quack_rms_norm_backward_meta( grad: TensorProxy, @@ -411,21 +463,26 @@ def quack_rms_norm_checker( eps: float | None = None, ) -> bool: if ( - a.ndim != 2 - or weight.ndim != 1 + weight.ndim != 1 or a.shape[-1] != weight.shape[0] or a.dtype not in {dtypes.float16, dtypes.bfloat16, dtypes.float32} or weight.dtype not in {dtypes.float16, dtypes.bfloat16, dtypes.float32} ): return False - return weight is not None and is_device_quack_compat() + return weight is not None and is_device_quack_compat() and is_last_dim_divisible(a.dtype, a.shape[-1]) def quack_rms_norm_aug_forward_impl( x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6, ) -> tuple[torch.Tensor, torch.Tensor]: - return _rmsnorm_fwd(x, weight, eps, return_rstd=True) + original_shape = x.shape + if requires_reshape := x.ndim > 2: + x = x.view(-1, original_shape[-1]) + fwd, rstd = _rmsnorm_fwd(x, weight, eps, return_rstd=True) + if requires_reshape: + fwd = fwd.view(original_shape) + return fwd, rstd def quack_rms_norm_aug_forward_meta( x: TensorProxy, @@ -451,6 +508,23 @@ def quack_rms_norm_transform( eps = 1e-6 return quack_rms_norm_aug_forward(a, weight, eps)[0] + # NOTE: The backward looks not functioning: + # def _engine_run_backward( + # t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], + # *args: Any, + # **kwargs: Any, + # ) -> tuple[torch.Tensor, ...]: + # attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG + # if attach_logging_hooks: + # unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) + # try: + # > return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass + # t_outputs, *args, **kwargs + # ) # Calls into the C++ engine to run the backward pass + # E RuntimeError: One of the differentiated Tensors does not require grad + # + # /pytorch/torch/autograd/graph.py:829: RuntimeError + def quack_rms_norm_grad( a: TensorProxy, /, @@ -471,5 +545,5 @@ def quack_rms_norm_grad( ltorch.rms_norm, checker=quack_rms_norm_checker, execution_transform=quack_rms_norm_transform, - grad_transform=quack_rms_norm_grad, + # grad_transform=quack_rms_norm_grad, ) diff --git a/thunder/tests/test_cutlass_dsl_ex.py b/thunder/tests/test_cutlass_dsl_ex.py index a532f1a8ab..d787510ef1 100644 --- a/thunder/tests/test_cutlass_dsl_ex.py +++ b/thunder/tests/test_cutlass_dsl_ex.py @@ -23,6 +23,8 @@ ) _DTYPES = (torch.float16, torch.bfloat16, torch.float32) _DTYPE_IDS = tuple(str(a) for a in _DTYPES) +_SHAPES = ((128, 1024), (3, 139, 641), (3, 3, 128, 1024)) +_SHAPE_IDS = ("2d", "incompat_3d", "3d") @pytest.fixture(autouse=True, scope="module") @@ -39,7 +41,7 @@ def set_cuda_as_default_device(): def jit_with_cutlass_dsl_ex(fn: Callable[[Any], Any]) -> Callable[[Any], Any]: - return thunder.jit(fn, executors=[cutlass_dsl_ex]) + return thunder.jit(fn, executors=[cutlass_dsl_ex], disable_torch_autograd=True) @requiresCUDA @@ -58,16 +60,13 @@ def test_quack_cross_entropy(dtype: torch.dtype): actual = jitted(x, targets, reduction="none") torch.testing.assert_close(expected, actual) - # expected_grad = torch.autograd.grad((expected,), (ref_x, targets), ) - # actual_grad = torch.autograd.grad((actual,), (x, targets)) - # torch.testing.assert_close(expected_grad, actual_grad) - @requiresCUDA @quack_available +@pytest.mark.parametrize("shape", _SHAPES, ids=_SHAPE_IDS) @pytest.mark.parametrize("dtype", _DTYPES, ids=_DTYPE_IDS) -def test_quack_softmax(dtype: torch.dtype): - x = torch.randn((128, 1024), dtype=dtype, requires_grad=True) +def test_quack_softmax(dtype: torch.dtype, shape: tuple[int, ...]): + x = torch.randn(shape, dtype=dtype, requires_grad=True) ref_x = x.clone().detach() jitted = jit_with_cutlass_dsl_ex(F.softmax) @@ -76,19 +75,16 @@ def test_quack_softmax(dtype: torch.dtype): actual = jitted(x, dim=-1) torch.testing.assert_close(expected, actual) - # expected_grad = torch.autograd.grad((expected,), (ref_x,)) - # actual_grad = torch.autograd.grad((actual,), (x,)) - # torch.testing.assert_close(expected_grad, actual_grad) - @requiresCUDA @quack_available +@pytest.mark.parametrize("shape", _SHAPES, ids=_SHAPE_IDS) @pytest.mark.parametrize("dtype", _DTYPES, ids=_DTYPE_IDS) -def test_quack_layernorm(dtype: torch.dtype): - x = torch.randn((128, 1024), dtype=dtype, requires_grad=True) +def test_quack_layernorm(dtype: torch.dtype, shape: tuple[int, ...]): + x = torch.randn(shape, dtype=dtype, requires_grad=True) ref_x = x.clone().detach().to(torch.float32) - module = nn.LayerNorm(1024).cuda() + module = nn.LayerNorm(shape[-1]).cuda() jitted = jit_with_cutlass_dsl_ex(module) expected = module(ref_x).to(dtype) @@ -98,18 +94,15 @@ def test_quack_layernorm(dtype: torch.dtype): @requiresCUDA @quack_available +@pytest.mark.parametrize("shape", _SHAPES, ids=_SHAPE_IDS) @pytest.mark.parametrize("dtype", _DTYPES, ids=_DTYPE_IDS) -def test_quack_rmsnorm(dtype: torch.dtype): - x = torch.randn((128, 1024), dtype=dtype, requires_grad=True) +def test_quack_rmsnorm(dtype: torch.dtype, shape: tuple[int, ...]): + x = torch.randn(shape, dtype=dtype, requires_grad=True) ref_x = x.clone().detach() - module = nn.RMSNorm(1024).cuda() + module = nn.RMSNorm(shape[-1]).cuda() jitted = jit_with_cutlass_dsl_ex(module) expected = module(ref_x) actual = jitted(x) torch.testing.assert_close(expected, actual) - - # expected_grad = torch.autograd.grad((expected,), (ref_x,)) - # actual_grad = torch.autograd.grad((actual,), (x,)) - # torch.testing.assert_close(expected_grad, actual_grad) From b9c876dc8e29b735a7f9faee6aabb201420d480f Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Wed, 6 Aug 2025 01:31:07 -0700 Subject: [PATCH 18/22] call non-augmented forward in execution transform Signed-off-by: Masaki Kozuki --- thunder/executors/cutlass_dsl_ex.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/executors/cutlass_dsl_ex.py b/thunder/executors/cutlass_dsl_ex.py index c2d6fbed27..a08c821bca 100644 --- a/thunder/executors/cutlass_dsl_ex.py +++ b/thunder/executors/cutlass_dsl_ex.py @@ -506,7 +506,7 @@ def quack_rms_norm_transform( ) -> TensorProxy: if eps is None: eps = 1e-6 - return quack_rms_norm_aug_forward(a, weight, eps)[0] + return quack_rms_norm_forward(a, weight, eps) # NOTE: The backward looks not functioning: # def _engine_run_backward( From 340f14e312428164a091df80aada20f5698d86b6 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Wed, 6 Aug 2025 02:30:04 -0700 Subject: [PATCH 19/22] quack bench Signed-off-by: Masaki Kozuki --- thunder/benchmarks/targets.py | 202 ++++++++++++++++++++++++++++++++++ 1 file changed, 202 insertions(+) diff --git a/thunder/benchmarks/targets.py b/thunder/benchmarks/targets.py index 636dcb8f9b..62dc9d42ad 100644 --- a/thunder/benchmarks/targets.py +++ b/thunder/benchmarks/targets.py @@ -33,10 +33,12 @@ HFBenchmark, LinearLoRABenchmark, DeepSeekSGLangMoEBenchmark, + UserFacingBenchmarkMeta, thunder_apex_executor, thunder_apex_nvfuser_executor, thunder_cudnn_executor, thunder_cudnn_nvfuser_executor, + thunder_cudnn_layer_norm_executor, thunder_executor, thunderfx_executor, thunder_sdpa_torch_compile_nvfuser_executor, @@ -1028,3 +1030,203 @@ def test_optim_functional( args, kwargs = bench.make_batch() benchmark_for_compute_type(compute_type, benchmark, fn, args, kwargs) + + +def cutlass_dsl_ex_executor(fn: Callable) -> Callable: + from thunder.executors.cutlass_dsl_ex import cutlass_dsl_ex + + torch.backends.cuda.matmul.allow_tf32 = True + return thunder.jit(fn, disable_torch_autograd=True, executors=[cutlass_dsl_ex]) + + +def nvfuserex_executor(fn: Callable) -> Callable: + from thunder.executors.nvfuserex import nvfuserex + + torch.backends.cuda.matmul.allow_tf32 = True + return thunder.jit(fn, executors=[nvfuserex]) + + +class BaseBenchmarkForQuack(Benchmark, metaclass=UserFacingBenchmarkMeta): + from thunder.benchmarks import BenchmarkArg + + _args = ( + BenchmarkArg("shape", description="The shape of the input tensor"), + BenchmarkArg("dtype", description="The dtype of the input tensor"), + BenchmarkArg("fn", description="The function to benchmark"), + ) + + def __init__(self, shape: tuple[int, int], fn: Callable): + self.shape = shape + self._fn = fn + + @property + def description(self) -> str: + return f"Benchmark for cross_entropy, softmax, layernorm, and rmsnorm with cutlass_dsl_ex, nvfuserex, and torch_compile_executor" + + +class CrossEntropyBenchmarkForQuack(BaseBenchmarkForQuack): + def __init__(self, shape: tuple[int, int], dtype: torch.dtype): + super().__init__(shape, torch.nn.functional.cross_entropy) + self.dtype = dtype + + def make_batch(self) -> tuple[list, dict]: + return [ + torch.randn(self.shape, device="cuda", dtype=self.dtype), + torch.randint(0, 16, (self.shape[0],), device="cuda"), + ], {} + + @property + def name(self) -> str: + return f"CrossEntropyBenchmarkForQuack({self.shape})" + + def fn(self) -> Callable: + def f(*args): + return self._fn(*args, reduction="none") + + return f + + +class SoftmaxBenchmarkForQuack(BaseBenchmarkForQuack): + def __init__(self, shape: tuple[int, int], dtype: torch.dtype): + super().__init__(shape, torch.nn.functional.softmax) + self.dtype = dtype + + def make_batch(self) -> tuple[list, dict]: + return [torch.randn(self.shape, device="cuda", dtype=self.dtype)], {} + + @property + def name(self) -> str: + return f"SoftmaxBenchmarkForQuack({self.shape})" + + def fn(self) -> Callable: + def f(*args): + return self._fn(*args, dim=-1) + + return f + + +class LayerNormBenchmarkForQuack(BaseBenchmarkForQuack): + def __init__(self, shape: tuple[int, int], dtype: torch.dtype): + import torch.nn as nn + + super().__init__(shape, torch.nn.functional.layer_norm) + self.dtype = dtype + self.layer = nn.LayerNorm(self.shape[1]).to(device="cuda", dtype=self.dtype) + + def make_batch(self) -> tuple[list, dict]: + return [torch.randn(self.shape, device="cuda", dtype=self.dtype)], {} + + @property + def name(self) -> str: + return f"LayerNormBenchmarkForQuack({self.shape})" + + def fn(self) -> Callable: + def f(*args): + return self.layer(*args) + + return f + + +class RMSNormBenchmarkForQuack(BaseBenchmarkForQuack): + def __init__(self, shape: tuple[int, int], dtype: torch.dtype): + import torch.nn as nn + + super().__init__(shape, torch.nn.functional.rms_norm) + self.dtype = dtype + self.layer = nn.RMSNorm(self.shape[1]).to(device="cuda", dtype=self.dtype) + + def make_batch(self) -> tuple[list, dict]: + return [torch.randn(self.shape, device="cuda", dtype=self.dtype)], {} + + @property + def name(self) -> str: + return f"RMSNormBenchmarkForQuack({self.shape})" + + def fn(self) -> Callable: + def f(*args): + return self.layer(*args) + + return f + + +# Benchmark for cross_entropy, softmax, layernorm, and rmsnorm with cutlass_dsl_ex, nvfuserex, and torch_compile_executor +# Input shapes (M, N) should cover the following cases +quack_bench_executors = ( + cutlass_dsl_ex_executor, + nvfuserex_executor, + torch_compile_executor, +) +quack_bench_shapes = ( + (32768, 512), + (32768, 1024), + (32768, 2048), + (32768, 4096), + (32768, 8192), + (32768, 16384), + (32768, 32768), + (32768, 65536), + (32768, 131072), + (32768, 262144), + (8192, 512), + (8192, 1024), + (8192, 2048), + (8192, 4096), + (8192, 8192), + (8192, 16384), + (8192, 32768), + (8192, 65536), + (8192, 131072), + (8192, 262144), +) +quack_bench_shape_ids = [f"{m}_{n}" for m, n in quack_bench_shapes] +dtypes = ( + torch.float32, + torch.bfloat16, + torch.float16, +) + + +def _run_benchmark_for_quack( + benchmark, executor, benchmark_cls, dtype, shape: tuple[int, int], compute_type: ComputeType +): + bench = benchmark_cls(shape, dtype) + args, kwargs = bench.make_batch() + fn = executor(bench.fn()) + benchmark_for_compute_type(compute_type, benchmark, fn, args, kwargs) + + +@pytest.mark.parametrize("executor", quack_bench_executors) +@pytest.mark.parametrize("dtype", dtypes, ids=(str(d) for d in dtypes)) +@pytest.mark.parametrize("shape", quack_bench_shapes, ids=quack_bench_shape_ids) +@parametrize_compute_type_only_inference +def test_benchmark_quack_cross_entropy(benchmark, executor, dtype, shape: tuple[int, int], compute_type: ComputeType): + _run_benchmark_for_quack(benchmark, executor, CrossEntropyBenchmarkForQuack, dtype, shape, compute_type) + + +@pytest.mark.parametrize("executor", quack_bench_executors) +@pytest.mark.parametrize("dtype", dtypes, ids=(str(d) for d in dtypes)) +@pytest.mark.parametrize("shape", quack_bench_shapes, ids=quack_bench_shape_ids) +@parametrize_compute_type_only_inference +def test_benchmark_quack_softmax(benchmark, executor, dtype, shape: tuple[int, int], compute_type: ComputeType): + _run_benchmark_for_quack(benchmark, executor, SoftmaxBenchmarkForQuack, dtype, shape, compute_type) + + +quack_layer_norm_executors = quack_bench_executors +if thunder_cudnn_layer_norm_executor is not None: + quack_layer_norm_executors += (thunder_cudnn_layer_norm_executor,) + + +@pytest.mark.parametrize("executor", quack_layer_norm_executors) +@pytest.mark.parametrize("dtype", dtypes, ids=(str(d) for d in dtypes)) +@pytest.mark.parametrize("shape", quack_bench_shapes, ids=quack_bench_shape_ids) +@parametrize_compute_type_only_inference +def test_benchmark_quack_layer_norm(benchmark, executor, dtype, shape: tuple[int, int], compute_type: ComputeType): + _run_benchmark_for_quack(benchmark, executor, LayerNormBenchmarkForQuack, dtype, shape, compute_type) + + +@pytest.mark.parametrize("executor", quack_bench_executors) +@pytest.mark.parametrize("dtype", dtypes, ids=(str(d) for d in dtypes)) +@pytest.mark.parametrize("shape", quack_bench_shapes, ids=quack_bench_shape_ids) +@parametrize_compute_type_only_inference +def test_benchmark_quack_rms_norm(benchmark, executor, dtype, shape: tuple[int, int], compute_type: ComputeType): + _run_benchmark_for_quack(benchmark, executor, RMSNormBenchmarkForQuack, dtype, shape, compute_type) From 968420c02e13f0ba6301d376d6550304afee745c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 6 Aug 2025 13:14:59 +0000 Subject: [PATCH 20/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/benchmarks/targets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/benchmarks/targets.py b/thunder/benchmarks/targets.py index 62dc9d42ad..2e8f3938b3 100644 --- a/thunder/benchmarks/targets.py +++ b/thunder/benchmarks/targets.py @@ -1061,7 +1061,7 @@ def __init__(self, shape: tuple[int, int], fn: Callable): @property def description(self) -> str: - return f"Benchmark for cross_entropy, softmax, layernorm, and rmsnorm with cutlass_dsl_ex, nvfuserex, and torch_compile_executor" + return "Benchmark for cross_entropy, softmax, layernorm, and rmsnorm with cutlass_dsl_ex, nvfuserex, and torch_compile_executor" class CrossEntropyBenchmarkForQuack(BaseBenchmarkForQuack): From b4408d1adaafafa3e09ae4075c42e48c566ba8f8 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Wed, 3 Sep 2025 13:03:48 +0900 Subject: [PATCH 21/22] fix quack availability check Signed-off-by: Masaki Kozuki --- thunder/tests/test_cutlass_dsl_ex.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/thunder/tests/test_cutlass_dsl_ex.py b/thunder/tests/test_cutlass_dsl_ex.py index d787510ef1..2f2545a98e 100644 --- a/thunder/tests/test_cutlass_dsl_ex.py +++ b/thunder/tests/test_cutlass_dsl_ex.py @@ -9,7 +9,6 @@ import thunder from thunder.executors.cutlass_dsl_ex import cutlass_dsl_ex, is_device_quack_compat -from thunder.tests.framework import requiresCUDA if TYPE_CHECKING: from typing import Any @@ -18,7 +17,7 @@ _quack_available = find_spec("quack") is not None quack_available = pytest.mark.skipif( - not is_device_quack_compat() or not _quack_available, + not torch.cuda.is_available() or not _quack_available or not is_device_quack_compat(), reason="quack requires SM9.0/10.0", ) _DTYPES = (torch.float16, torch.bfloat16, torch.float32) @@ -44,7 +43,6 @@ def jit_with_cutlass_dsl_ex(fn: Callable[[Any], Any]) -> Callable[[Any], Any]: return thunder.jit(fn, executors=[cutlass_dsl_ex], disable_torch_autograd=True) -@requiresCUDA @quack_available @pytest.mark.parametrize("dtype", _DTYPES, ids=_DTYPE_IDS) def test_quack_cross_entropy(dtype: torch.dtype): @@ -61,7 +59,6 @@ def test_quack_cross_entropy(dtype: torch.dtype): torch.testing.assert_close(expected, actual) -@requiresCUDA @quack_available @pytest.mark.parametrize("shape", _SHAPES, ids=_SHAPE_IDS) @pytest.mark.parametrize("dtype", _DTYPES, ids=_DTYPE_IDS) @@ -76,7 +73,6 @@ def test_quack_softmax(dtype: torch.dtype, shape: tuple[int, ...]): torch.testing.assert_close(expected, actual) -@requiresCUDA @quack_available @pytest.mark.parametrize("shape", _SHAPES, ids=_SHAPE_IDS) @pytest.mark.parametrize("dtype", _DTYPES, ids=_DTYPE_IDS) @@ -92,7 +88,6 @@ def test_quack_layernorm(dtype: torch.dtype, shape: tuple[int, ...]): torch.testing.assert_close(expected, actual) -@requiresCUDA @quack_available @pytest.mark.parametrize("shape", _SHAPES, ids=_SHAPE_IDS) @pytest.mark.parametrize("dtype", _DTYPES, ids=_DTYPE_IDS) From 1e5f3b2d709a15d7b625f1c651cfbbb5cb55a844 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Tue, 11 Nov 2025 03:36:07 +0900 Subject: [PATCH 22/22] mandate weight in layer|rms norm Signed-off-by: Masaki Kozuki --- thunder/executors/cutlass_dsl_ex.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/thunder/executors/cutlass_dsl_ex.py b/thunder/executors/cutlass_dsl_ex.py index a08c821bca..e60163eba1 100644 --- a/thunder/executors/cutlass_dsl_ex.py +++ b/thunder/executors/cutlass_dsl_ex.py @@ -373,7 +373,7 @@ def quack_layer_norm_checker( ) -> bool: if ( a.dtype not in {dtypes.float16, dtypes.bfloat16, dtypes.float32} - or weight.ndim != 1 + or (weight is None or weight.ndim != 1) or a.shape[-1] != weight.shape[0] or weight.dtype not in {dtypes.float32} ): @@ -463,7 +463,7 @@ def quack_rms_norm_checker( eps: float | None = None, ) -> bool: if ( - weight.ndim != 1 + (weight is None or weight.ndim != 1) or a.shape[-1] != weight.shape[0] or a.dtype not in {dtypes.float16, dtypes.bfloat16, dtypes.float32} or weight.dtype not in {dtypes.float16, dtypes.bfloat16, dtypes.float32}