Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions thunder/executors/tilegymex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import warnings

from lightning_utilities.core.imports import package_available

from thunder import Transform
from thunder.extend import OperatorExecutor

__all__ = ["tilegym_ex", "TileGymTransform"]


# Both public names default to None so importing this module never fails when
# the optional `tilegym` package is missing; callers are expected to check for None.
tilegym_ex: None | OperatorExecutor = None
TileGymTransform: None | Transform = None


if not package_available("tilegym"):
    warnings.warn("tilegym module not found!")
else:
    # The implementation module imports tilegym at module scope, so it is only
    # pulled in once the package is known to be importable.
    import thunder.executors.tilegymex_impl as impl

    tilegym_ex = impl.tilegym_ex
    TileGymTransform = impl.TileGymTransform
254 changes: 254 additions & 0 deletions thunder/executors/tilegymex_impl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING

import torch
from lightning_utilities.core.imports import package_available

import thunder.core.devices as devices
import thunder.core.dtypes as dtypes
from thunder.core.proxies import pyval
from thunder.extend import OperatorExecutor, register_executor
from thunder import Transform
import thunder.torch as ltorch

if TYPE_CHECKING:
from thunder.torch import TensorLike


# Fail fast with a clear message instead of a bare ModuleNotFoundError from the
# tilegym imports just below.
if not package_available("tilegym"):
    raise ImportError("tilegym is required for the tilegym executor")

import tilegym
from tilegym import ops as tg_ops


# The executor advertises tilegym's own version when the package exposes one
# (None otherwise) and is registered with Thunder under the name "tilegym".
tilegym_ex: OperatorExecutor = OperatorExecutor(
    "tilegym",
    version=getattr(tilegym, "__version__", None),
)
register_executor(tilegym_ex)


def _is_cuda_tensor(t: TensorLike) -> bool:
    """Return True when the device type of ``t`` is CUDA."""
    device_kind = t.device.devicetype
    return device_kind == devices.DeviceType.CUDA


def _pybool(x) -> bool:
    """Best-effort conversion of ``x`` to a Python bool.

    The value is unwrapped with ``pyval`` and then passed to ``bool``. Any
    exception raised along the way is swallowed and reported as ``False``.
    """
    try:
        concrete = pyval(x)
        return bool(concrete)
    except Exception:
        return False


def _pyfloat_or_none(x) -> float | None:
if x is None:
return None
try:
return float(pyval(x))
except Exception:
return None


def _parse_min_cc(s: str) -> tuple[int, int] | None:
# Accept "10.0", "10,0", or "100" (treated as "10.0").
s = (s or "").strip()
if not s:
return None
if "." in s:
a, b = s.split(".", 1)
return int(a), int(b)
if "," in s:
a, b = s.split(",", 1)
return int(a), int(b)
if s.isdigit():
if len(s) >= 2:
return int(s[:-1]), int(s[-1])
return int(s), 0
return None


def _tilegym_device_cc_ok(device_index: int) -> bool:
# Default to Blackwell+ (SM100). Override via env vars:
# - THUNDER_TILEGYM_ALLOW_ANY_CC=1 (bypass)
# - THUNDER_TILEGYM_MIN_CC=10.0 (set minimum)
if os.environ.get("THUNDER_TILEGYM_ALLOW_ANY_CC", "0").lower() in ("1", "true", "yes", "y", "on"):
return True

min_cc = _parse_min_cc(os.environ.get("THUNDER_TILEGYM_MIN_CC", "10.0"))
if min_cc is None:
min_cc = (10, 0)

if not torch.cuda.is_available():
return False
try:
cc = torch.cuda.get_device_capability(device_index)
except Exception:
return False

return tuple(cc) >= tuple(min_cc)


def _tilegym_sdpa_checker(
    query: TensorLike,
    key: TensorLike,
    value: TensorLike,
    attn_mask: TensorLike | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    *,
    scale: float | None = None,
) -> bool:
    """Decide whether the TileGym executor may claim a scaled-dot-product-attention call.

    The signature mirrors ``scaled_dot_product_attention``. The call is
    accepted only when all of the following hold:

    * query/key/value are CUDA tensors on the same device, and that device
      passes ``_tilegym_device_cc_ok``;
    * there is no ``attn_mask`` and ``dropout_p`` is exactly 0.0;
    * none of the inputs requires grad;
    * all inputs are 4-D with matching batch, heads and head dim, and key and
      value share a sequence length;
    * the head dim is a multiple of 8;
    * all inputs share one dtype, float16 or bfloat16;
    * for causal attention, query and key sequence lengths are equal.

    ``scale`` is not inspected here.

    Returns:
        True when the call can be handled, False otherwise (including when
        ``dropout_p`` or ``is_causal`` cannot be resolved to Python values).
    """
    # TileGym kernels are CUDA-only.
    if not (_is_cuda_tensor(query) and _is_cuda_tensor(key) and _is_cuda_tensor(value)):
        return False

    if not _tilegym_device_cc_ok(query.device.index):
        return False

    if key.device != query.device or value.device != query.device:
        return False

    # TileGym kernels currently don't support explicit masks or dropout.
    if attn_mask is not None:
        return False

    # The shape restrictions below depend on the concrete values of both
    # settings. If either cannot be resolved, decline instead of guessing:
    # silently treating an unresolved is_causal as False would skip the
    # causal-only restrictions.
    try:
        dropout_p_val = float(pyval(dropout_p))
        is_causal_val = bool(pyval(is_causal))
    except Exception:
        return False
    if dropout_p_val != 0.0:
        return False

    # TileGym attention kernels don't implement backward yet.
    if query.requires_grad or key.requires_grad or value.requires_grad:
        return False

    # Expected shapes: (B, H, S, D)
    if query.ndim != 4 or key.ndim != 4 or value.ndim != 4:
        return False

    bq, hq, sq, dq = query.shape
    bk, hk, sk, dk = key.shape
    bv, hv, sv, dv = value.shape

    if bq != bk or bq != bv:
        return False
    if hq != hk or hq != hv:
        # Thunder/torch SDPA expects same number of heads
        return False
    if sk != sv:
        return False
    if dq != dk or dq != dv:
        # TileGym fmha expects Dq == Dk == Dv
        return False

    # TileGym decode kernel assumes non-causal semantics for q_len==1 and k_len>1.
    if sq == 1 and sk > 1 and is_causal_val:
        return False

    # TileGym prefill causal assumes query positions start at 0 and align with keys.
    if is_causal_val and sq != sk:
        return False

    # D requirements: TensorCore-friendly.
    if dq % 8 != 0:
        return False

    # Dtype requirements (TileGym kernels use MMA paths).
    if query.dtype not in (dtypes.float16, dtypes.bfloat16):
        return False
    if key.dtype != query.dtype or value.dtype != query.dtype:
        return False

    # `scale` needs no validation: a symbolic/unknown scale does not block the
    # rewrite (per the original note, TileGym defaults to 1/sqrt(D)).
    return True


def _tilegym_sdpa_impl(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    *,
    scale: float | None = None,
) -> torch.Tensor:
    """Run attention through TileGym, picking the decode or the general kernel.

    ``attn_mask`` and ``dropout_p`` are accepted for signature compatibility
    but never read: the checker only admits calls with no mask and zero
    dropout.
    """
    # A single query position attending over several keys goes to the decode
    # kernel (non-causal semantics expected; the checker enforces that).
    single_query_decode = query.shape[2] == 1 and key.shape[2] > 1
    if not single_query_decode:
        return tg_ops.fmha(query, key, value, scaling=scale, is_causal=is_causal)
    return tg_ops.fmha_decode(query, key, value, sm_scale=scale)


# Expose the TileGym attention entry point as an executor operator whose
# signature is modelled on Thunder's scaled_dot_product_attention.
tilegym_sdpa = tilegym_ex.register_operator(
    "tilegym_scaled_dot_product_attention", like=ltorch.scaled_dot_product_attention, fn=_tilegym_sdpa_impl
)

# Offer that operator as an implementation of scaled_dot_product_attention,
# gated by the checker.
tilegym_ex.register_implementation(
    ltorch.scaled_dot_product_attention, op=tilegym_sdpa, checker=_tilegym_sdpa_checker
)


def _tilegym_rms_norm_checker(
    a: TensorLike,
    normalized_shape,
    weight: TensorLike | None = None,
    eps: float | None = None,
) -> bool:
    """Decide whether the TileGym executor may claim an ``rms_norm`` call.

    Accepts only a CUDA input on a device that passes
    ``_tilegym_device_cc_ok``, with a weight present on that same device,
    both sharing one dtype out of float16/bfloat16/float32, and an input that
    does not require grad. ``normalized_shape`` and ``eps`` are not inspected.
    """
    # Input must live on a supported CUDA device.
    if not _is_cuda_tensor(a):
        return False
    if not _tilegym_device_cc_ok(a.device.index):
        return False

    # TileGym rms_norm requires affine weight, co-located with the input.
    if weight is None:
        return False
    if not _is_cuda_tensor(weight):
        return False
    if weight.device != a.device:
        return False

    allowed_dtypes = (dtypes.float16, dtypes.bfloat16, dtypes.float32)
    if a.dtype not in allowed_dtypes:
        return False
    if weight.dtype != a.dtype:
        return False

    # TileGym rms_norm doesn't implement backward yet, so the rewrite is only
    # enabled when the *activation* does not require grad (typical inference
    # usage). normalized_shape is left to the underlying op to validate.
    return not a.requires_grad


def _tilegym_rms_norm_impl(
    a: torch.Tensor,
    normalized_shape,
    weight: torch.Tensor | None = None,
    eps: float | None = None,
) -> torch.Tensor:
    """Run RMS normalization through TileGym.

    When ``eps`` is None it falls back to the machine epsilon of ``a``'s
    dtype for floating-point inputs and to 0.0 otherwise. The checker ensures
    ``weight`` is present.
    """
    if eps is None:
        if a.dtype.is_floating_point:
            eps = torch.finfo(a.dtype).eps
        else:
            eps = 0.0
    return tg_ops.rms_norm(a, normalized_shape, weight, eps)


# No TileGym-specific transform is assigned in the code shown here; the name
# exists so importers can reference it, and it stays None.
TileGymTransform: Transform | None = None

# Register the rms_norm rewrite only when thunder.torch exposes an rms_norm symbol.
if hasattr(ltorch, "rms_norm"):
    tilegym_rms_norm = tilegym_ex.register_operator(
        "tilegym_rms_norm", like=ltorch.rms_norm, fn=_tilegym_rms_norm_impl
    )
    tilegym_ex.register_implementation(
        ltorch.rms_norm, op=tilegym_rms_norm, checker=_tilegym_rms_norm_checker
    )
1 change: 1 addition & 0 deletions thunder/extend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,7 @@ def get_all_executors() -> tuple[Executor, ...]:
pythonex,
sdpaex,
fa3ex,
tilegymex,
torch_compile,
torchex,
transformer_engineex,
Expand Down
54 changes: 54 additions & 0 deletions thunder/tests/test_tilegym_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import pytest
import torch

import thunder
from lightning_utilities.core.imports import package_available


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
@pytest.mark.skipif(not package_available("tilegym"), reason="requires tilegym")
def test_tilegym_executor_sdpa_rewrites_and_runs():
    """SDPA is claimed by the TileGym executor and matches eager PyTorch."""
    # The executor's checker declines GPUs below the configured minimum compute
    # capability, in which case the "claimed by tilegym" assertion at the end
    # cannot hold. Skip using the very same gate (it honours the
    # THUNDER_TILEGYM_* environment overrides).
    from thunder.executors.tilegymex_impl import _tilegym_device_cc_ok

    if not _tilegym_device_cc_ok(torch.cuda.current_device()):
        pytest.skip("current GPU is below the TileGym executor's minimum compute capability")

    tilegym_ex = thunder.get_executor("tilegym")
    assert tilegym_ex is not None

    def fn(q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)

    # Choose a shape that avoids other SDPA executors' restrictions interfering with this test:
    # - Head dim divisible by 8
    # - No explicit attn_mask, no dropout
    B, H, S, D = 2, 8, 256, 128
    q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)

    # tilegym_ex goes first so it gets the first chance to claim the op.
    jfn = thunder.jit(fn, executors=(tilegym_ex, *thunder.get_default_executors()))
    out = jfn(q, k, v)
    ref = fn(q, k, v)

    torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)

    # At least one symbol of the final trace must be executed by TileGym.
    trace = thunder.last_traces(jfn)[-1]
    assert any(bsym.sym.executor is tilegym_ex for bsym in trace.bound_symbols)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
@pytest.mark.skipif(not package_available("tilegym"), reason="requires tilegym")
def test_tilegym_executor_rms_norm_rewrites_and_runs():
    """rms_norm is claimed by the TileGym executor and matches eager PyTorch."""
    # The executor's checker declines GPUs below the configured minimum compute
    # capability, in which case the "claimed by tilegym" assertion at the end
    # cannot hold. Skip using the very same gate (it honours the
    # THUNDER_TILEGYM_* environment overrides).
    from thunder.executors.tilegymex_impl import _tilegym_device_cc_ok

    if not _tilegym_device_cc_ok(torch.cuda.current_device()):
        pytest.skip("current GPU is below the TileGym executor's minimum compute capability")

    tilegym_ex = thunder.get_executor("tilegym")
    assert tilegym_ex is not None

    def fn(x, w):
        return torch.nn.functional.rms_norm(x, (x.shape[-1],), w, 1e-6)

    x = torch.randn(4, 128, device="cuda", dtype=torch.bfloat16, requires_grad=False)
    w = torch.randn(128, device="cuda", dtype=torch.bfloat16, requires_grad=False)

    # tilegym_ex goes first so it gets the first chance to claim the op.
    jfn = thunder.jit(fn, executors=(tilegym_ex, *thunder.get_default_executors()))
    out = jfn(x, w)
    ref = fn(x, w)

    # NOTE(review): zero tolerance presumes the TileGym kernel is bitwise
    # identical to PyTorch's rms_norm in bfloat16 -- confirm this is intended.
    torch.testing.assert_close(out, ref, atol=0, rtol=0)

    # At least one symbol of the final trace must be executed by TileGym.
    trace = thunder.last_traces(jfn)[-1]
    assert any(bsym.sym.executor is tilegym_ex for bsym in trace.bound_symbols)
Loading