
Commit 0996865

Switch from TensorDescriptor to tl.make_tensor_descriptor (#232)
This rewrites a lot of the tensor descriptor handling to properly apply its requirements.
1 parent 52306b0 commit 0996865

13 files changed: +773, -150 lines

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -86,3 +86,5 @@ venv
 .watchmanconfig
 *.zip
 CLAUDE.md
+triton
+torch

helion/_compat.py

Lines changed: 2 additions & 30 deletions
@@ -1,11 +1,11 @@
 from __future__ import annotations

 import functools
-import importlib

 import torch
 from torch._inductor.runtime.hints import DeviceProperties
 from torch._inductor.utils import triton_type
+import triton
 from triton.backends.compiler import GPUTarget
 import triton.language as tl

@@ -22,35 +22,7 @@ def _supports_tensor_descriptor() -> bool:
     major, _ = torch.cuda.get_device_capability(torch.cuda.current_device())
     if major < 9:
         return False
-    try:
-        return get_triton_tensor_descriptor_class() is not None
-    except ImportError:
-        return False
-
-
-@functools.cache
-def get_triton_tensor_descriptor_class_import_path() -> str:
-    cls = get_triton_tensor_descriptor_class()
-    return f"from {cls.__module__} import {cls.__qualname__}"
-
-
-@functools.cache
-def get_triton_tensor_descriptor_class() -> type[object]:
-    """Attempt to import TensorDescriptor class from known Triton modules."""
-    possible_modules = [
-        "triton.tools.tensor_descriptor",
-        "triton.tools.experimental_descriptor",
-    ]
-    for module_name in possible_modules:
-        try:
-            module = importlib.import_module(module_name)
-            if hasattr(module, "TensorDescriptor"):
-                return module.TensorDescriptor
-        except ImportError:
-            continue
-    raise ImportError(
-        "TensorDescriptor class not found in any of the known Triton modules."
-    )
+    return hasattr(triton.language, "make_tensor_descriptor")


 @functools.cache
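For context, support detection now reduces to checking that Triton exposes tl.make_tensor_descriptor, which builds the descriptor inside the kernel instead of constructing a host-side TensorDescriptor object. Below is a minimal sketch of that device-side pattern, assuming a recent Triton on an SM90+ GPU; the allocator and descriptor calls follow Triton's own tutorials and may differ across Triton versions, and this is illustrative only, not Helion's generated code.

    import torch
    import triton
    import triton.language as tl

    def _alloc_fn(size: int, alignment: int, stream):
        # Triton requests scratch memory through this hook when a kernel
        # builds tensor descriptors on the device.
        return torch.empty(size, device="cuda", dtype=torch.int8)

    triton.set_allocator(_alloc_fn)

    @triton.jit
    def copy_kernel(in_ptr, out_ptr, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
        # Descriptors are built from the raw pointer plus sizes, strides, and
        # block shape; the tensors are row-major here, so the last stride is 1.
        in_desc = tl.make_tensor_descriptor(in_ptr, [M, N], [N, 1], [BLOCK_M, BLOCK_N])
        out_desc = tl.make_tensor_descriptor(out_ptr, [M, N], [N, 1], [BLOCK_M, BLOCK_N])
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)
        block = in_desc.load([pid_m * BLOCK_M, pid_n * BLOCK_N])
        out_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], block)

    x = torch.randn(256, 256, device="cuda")
    y = torch.empty_like(x)
    copy_kernel[(2, 2)](x, y, 256, 256, BLOCK_M=128, BLOCK_N=128)

Registering an allocator is required for device-built descriptors, which is presumably what helion.runtime.set_triton_allocator() takes care of in the codegen_function_def change further down.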

helion/_compiler/device_function.py

Lines changed: 91 additions & 22 deletions
@@ -74,15 +74,35 @@ def sort_key(self) -> tuple[object, ...]:
 @dataclasses.dataclass
 class TensorArg(Argument):
     fake_value: torch.Tensor
-    _host_str: str
+    _host_str: str | None

     def host_str(self) -> str:
+        if self._host_str is None:
+            raise RuntimeError("TensorArg has no host representation")
         return self._host_str


 @dataclasses.dataclass
 class TensorDescriptorArg(TensorArg):
-    pass
+    # Permutation applied to make stride==1 dimension last
+    permutation: list[int] | None = None
+
+    def host_str(self) -> str:
+        if self._host_str is None:
+            raise RuntimeError(
+                "TensorDescriptorArg is device-only and has no host representation"
+            )
+        return self._host_str
+
+    @property
+    def inverse_permutation(self) -> list[int]:
+        """Get the inverse permutation to undo the applied permutation."""
+        if (permutation := self.permutation) is None:
+            raise RuntimeError("TensorDescriptorArg.permutation is None")
+        inverse_perm = [0] * len(permutation)
+        for i, p in enumerate(permutation):
+            inverse_perm[p] = i
+        return inverse_perm


 @dataclasses.dataclass
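The inverse_permutation property added above is plain index bookkeeping. A standalone sketch of the same computation with hypothetical values (a 3-D tensor whose stride==1 dimension is dim 0):

    # Moving dim 0 to the end gives permutation [1, 2, 0].
    permutation = [1, 2, 0]

    # Inverse permutation, computed the same way as inverse_permutation above.
    inverse_perm = [0] * len(permutation)
    for i, p in enumerate(permutation):
        inverse_perm[p] = i
    assert inverse_perm == [2, 0, 1]

    # Applying permutation and then inverse_perm restores the original order.
    dims = ["a", "b", "c"]
    permuted = [dims[i] for i in permutation]       # ['b', 'c', 'a']
    restored = [permuted[i] for i in inverse_perm]  # ['a', 'b', 'c']
    assert restored == dims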
@@ -144,6 +164,7 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
         self.config = config
         self.codegen = codegen
         self.arguments: list[Argument] = []
+        self.preamble: list[ast.AST] = []
         self.body: list[ast.AST] = []
         self._tensor_args: dict[torch.Tensor, TensorArg] = {}
         self._tensor_descriptor_args: dict[
@@ -272,20 +293,59 @@ def tensor_arg(

     def tensor_descriptor_arg(
         self, fake_value: torch.Tensor, block_size: list[int | torch.SymInt]
-    ) -> TensorArg:
+    ) -> TensorDescriptorArg:
         host_function = HostFunction.current()
-        block_size_expr = ", ".join(
-            map(HostFunction.current().literal_expr, block_size)
-        )
+        block_size_expr = ", ".join(map(self.literal_expr, block_size))
         key = (fake_value, block_size_expr)
         if key not in self._tensor_descriptor_args:
             origin = host_function.tensor_to_origin[fake_value]
+            desc_name = self.new_var(origin.suggest_var_name() + "_desc")
+            env = CompileEnvironment.current()
+
+            # Find which dimension has stride==1
+            stride_one_dim = [*map(env.size_hint, fake_value.stride())].index(1)
+
+            # Determine if we need permutation (stride==1 dimension is not last)
+            permutation = None
+            if stride_one_dim != fake_value.ndim - 1:
+                # Create permutation to move stride==1 dimension to last position
+                permutation = [*range(fake_value.ndim)]
+                permutation.pop(stride_one_dim)
+                permutation.append(stride_one_dim)
+
+            # Create the regular tensor arg and size/stride args
+            tensor_arg = self.tensor_arg(fake_value)
+            size_args = [
+                self.tensor_size(fake_value, i) for i in range(fake_value.ndim)
+            ]
+            stride_args = [
+                self.tensor_stride(fake_value, i) for i in range(fake_value.ndim)
+            ]
+
+            # Apply permutation if needed
+            if permutation is not None:
+                size_args = [size_args[i] for i in permutation]
+                stride_args = [stride_args[i] for i in permutation]
+                block_size = [block_size[i] for i in permutation]
+                # Update block_size_expr for the permuted order
+                block_size_expr = ", ".join(map(self.literal_expr, block_size))
+
+            # Add tl.make_tensor_descriptor call to preamble
+            sizes = ", ".join([arg.name for arg in size_args])
+            strides = ", ".join([arg.name for arg in stride_args])
+
+            descriptor_stmt = statement_from_string(
+                f"{desc_name} = tl.make_tensor_descriptor({tensor_arg.name}, [{sizes}], [{strides}], [{block_size_expr}])"
+            )
+            self.preamble.append(descriptor_stmt)
+
             arg = TensorDescriptorArg(
-                self.new_var(origin.suggest_var_name() + "_desc"),
+                desc_name,
                 fake_value,
-                f"TensorDescriptor.from_tensor({origin.host_str()}, [{block_size_expr}])",
+                None,  # No host_str since this is device-only
+                permutation,
             )
-            self.arguments.append(arg)
+            # Don't add to self.arguments since this is device-only
             self._tensor_descriptor_args[key] = arg
         return self._tensor_descriptor_args[key]
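As a concrete illustration of what this emits: for a hypothetical 2-D argument x stored column-major (stride 1 on dim 0), the permutation is [1, 0], so the sizes, strides, and block sizes are swapped before the descriptor statement lands in the kernel preamble, roughly as below. The variable names are illustrative of Helion's generated style, not exact output.

    # Hypothetical generated preamble for a column-major x (strides (1, M)).
    x_desc = tl.make_tensor_descriptor(
        x,
        [x_size_1, x_size_0],            # sizes permuted so the stride==1 dim is last
        [x_stride_1, x_stride_0],        # strides permuted (trailing stride is 1)
        [_BLOCK_SIZE_1, _BLOCK_SIZE_0],  # block sizes permuted to match
    )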

@@ -342,20 +402,28 @@ def sorted_args(self) -> list[Argument]:
         self.arguments.sort(key=lambda arg: arg.sort_key())
         return self.arguments

-    def codegen_function_def(self) -> ast.FunctionDef:
-        return ast_rename(
-            create(
-                ast.FunctionDef,
-                name=self.name,
-                args=create_arguments(
-                    [arg.arg_def_node() for arg in self.sorted_args()]
+    def codegen_function_def(self) -> list[ast.stmt]:
+        prefix = []
+        if self._tensor_descriptor_args:
+            prefix.append(
+                statement_from_string("helion.runtime.set_triton_allocator()")
+            )
+        return [
+            *prefix,
+            ast_rename(
+                create(
+                    ast.FunctionDef,
+                    name=self.name,
+                    args=create_arguments(
+                        [arg.arg_def_node() for arg in self.sorted_args()]
+                    ),
+                    body=[*self.preamble, *self.body],
+                    decorator_list=[expr_from_string("triton.jit")],
+                    type_params=[],
                 ),
-                body=self.body,
-                decorator_list=[expr_from_string("triton.jit")],
-                type_params=[],
+                {k: v[0] for k, v in self._variable_renames.items()},
             ),
-            {k: v[0] for k, v in self._variable_renames.items()},
-        )
+        ]

     def codegen_function_call(self) -> ast.AST:
         args = [arg.host_str() for arg in self.sorted_args()]
@@ -390,14 +458,15 @@ def dead_code_elimination(self) -> None:
         """

         for _ in range(8):
-            rw = ReadWrites.from_list(self.body)
+            rw = ReadWrites.from_list([*self.preamble, *self.body])
             to_remove = set()
             for name in self.dce_vars:
                 if name in rw.writes and name not in rw.reads:
                     to_remove.add(name)
             if not to_remove:
                 break
             self.body[:] = ast_delete_assignments(self.body, to_remove)
+            self.preamble[:] = ast_delete_assignments(self.preamble, to_remove)

         # drop any unused args
         args_to_remove = {
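Extending both the read/write scan and the deletion to the preamble lets unused descriptor assignments be cleaned up just like body statements. A self-contained sketch of this kind of dead-assignment pass, using plain ast helpers rather than Helion's ReadWrites/ast_delete_assignments:

    import ast

    def eliminate_dead_assignments(stmts: list[ast.stmt], removable: set[str]) -> list[ast.stmt]:
        # Collect every name read or written anywhere in the statements.
        reads: set[str] = set()
        writes: set[str] = set()
        for node in ast.walk(ast.Module(body=stmts, type_ignores=[])):
            if isinstance(node, ast.Name):
                if isinstance(node.ctx, ast.Load):
                    reads.add(node.id)
                else:
                    writes.add(node.id)
        # A candidate is dead if it is written but never read.
        dead = {name for name in removable if name in writes and name not in reads}
        return [
            stmt
            for stmt in stmts
            if not (
                isinstance(stmt, ast.Assign)
                and len(stmt.targets) == 1
                and isinstance(stmt.targets[0], ast.Name)
                and stmt.targets[0].id in dead
            )
        ]

    module = ast.parse("a = 1\nb = a + 1\nc = 5\nresult = b\n")
    kept = eliminate_dead_assignments(module.body, removable={"a", "b", "c"})
    print([ast.unparse(s) for s in kept])  # ['a = 1', 'b = a + 1', 'result = b']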

helion/_compiler/generate_ast.py

Lines changed: 1 addition & 1 deletion
@@ -407,7 +407,7 @@ def generate_ast(func: HostFunction, config: Config) -> ast.AST:
     result = ast.Module(
         [
             *func.codegen_imports(),
-            kernel_def,
+            *kernel_def,
             host_def,
             precompile_def,
         ],

helion/_compiler/indexing_strategy.py

Lines changed: 79 additions & 14 deletions
@@ -18,6 +18,7 @@

 if TYPE_CHECKING:
     from ..runtime.config import Config
+    from .device_function import TensorDescriptorArg
     from .inductor_lowering import CodegenState

@@ -145,28 +146,71 @@ def codegen_store(
 class TensorDescriptorIndexingStrategy(IndexingStrategy):
     """Use TensorDescriptor to load/store from tensors"""

-    def codegen_load(
-        self,
+    @staticmethod
+    def is_supported(
         state: CodegenState,
         fake_tensor: torch.Tensor,
         subscript: list[object],
         extra_mask: ast.AST | None,
-    ) -> ast.AST:
+    ) -> bool:
+        """Check if tensor descriptor indexing is supported with additional requirements."""
+        # First check the basic BlockedSubscriptIndexing requirements
         if not BlockedSubscriptIndexing.is_supported(
             state, fake_tensor, subscript, extra_mask
         ):
+            return False
+
+        # Additional tensor descriptor requirements:
+        # 1) ndim must be between 2 and 5
+        if not (2 <= fake_tensor.ndim <= 5):
+            return False
+
+        # 2) Exactly 1 dimension should have stride==1
+        env = CompileEnvironment.current()
+        stride_one_count = 0
+        element_size = fake_tensor.element_size()
+        for dim in range(fake_tensor.ndim):
+            stride = env.size_hint(fake_tensor.stride(dim))
+            if stride == 1:
+                stride_one_count += 1
+            else:
+                # 3) All other dimensions should have 16-byte aligned strides
+                byte_stride = stride * element_size
+                if byte_stride % 16 != 0:
+                    return False
+
+        # TODO(jansel): check that base_ptr is aligned to 16 bytes
+        return stride_one_count == 1
+
+    def codegen_load(
+        self,
+        state: CodegenState,
+        fake_tensor: torch.Tensor,
+        subscript: list[object],
+        extra_mask: ast.AST | None,
+    ) -> ast.AST:
+        if not self.is_supported(state, fake_tensor, subscript, extra_mask):
             return PointerIndexingStrategy().codegen_load(
                 state, fake_tensor, subscript, extra_mask
             )
         assert extra_mask is None
         indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript)
-        return indexing.reshape_load(
-            state,
-            expr_from_string(
-                f"{indexing.tensor_descriptor(state)}.load({indexing.offsets_str()})"
-            ),
+
+        # Load from tensor descriptor with permuted offsets
+        load_expr = expr_from_string(
+            f"{indexing.tensor_descriptor(state)}.load({indexing.offsets_str_permuted(state)})"
         )

+        # Apply inverse permutation to the loaded result if needed
+        desc_arg = indexing.tensor_descriptor_arg(state)
+        if desc_arg.permutation is not None:
+            load_expr = expr_from_string(
+                f"tl.permute(load_result, {desc_arg.inverse_permutation!r})",
+                load_result=load_expr,
+            )
+
+        return indexing.reshape_load(state, load_expr)
+
     def codegen_store(
         self,
         state: CodegenState,
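The extra checks in is_supported above mirror the hardware constraints on tensor-descriptor (TMA-style) accesses: rank between 2 and 5, exactly one contiguous dimension, and 16-byte aligned strides everywhere else. The same conditions can be sketched on a plain torch.Tensor, using real strides instead of Helion's size hints (a hypothetical helper, not part of the codebase):

    import torch

    def descriptor_requirements_ok(t: torch.Tensor) -> bool:
        # 1) ndim must be between 2 and 5.
        if not (2 <= t.ndim <= 5):
            return False
        # 2) Exactly one dimension must have stride == 1, and
        # 3) every other dimension's stride must be 16-byte aligned.
        stride_one_count = 0
        for dim in range(t.ndim):
            stride = t.stride(dim)
            if stride == 1:
                stride_one_count += 1
            elif stride * t.element_size() % 16 != 0:
                return False
        return stride_one_count == 1

    x = torch.empty(64, 64, dtype=torch.float32)  # strides (64, 1): 256-byte row stride, aligned
    y = torch.empty(64, 3, dtype=torch.float32)   # strides (3, 1): 12-byte row stride, not aligned
    print(descriptor_requirements_ok(x), descriptor_requirements_ok(y))  # True False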
@@ -175,17 +219,27 @@ def codegen_store(
         value: ast.AST,
         extra_mask: ast.AST | None,
     ) -> ast.AST:
-        if not BlockedSubscriptIndexing.is_supported(
-            state, fake_tensor, subscript, extra_mask
-        ):
+        if not self.is_supported(state, fake_tensor, subscript, extra_mask):
             return PointerIndexingStrategy().codegen_store(
                 state, fake_tensor, subscript, value, extra_mask
             )
         assert extra_mask is None
         indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript)
+
+        # Apply permutation to the value being stored if needed
+        desc_arg = indexing.tensor_descriptor_arg(state)
+        store_value = indexing.reshape_store(state, value)
+
+        if desc_arg.permutation is not None:
+            # Apply permutation to the value
+            store_value = expr_from_string(
+                f"tl.permute(store_val, {desc_arg.permutation!r})",
+                store_val=store_value,
+            )
+
         return expr_from_string(
-            f"{indexing.tensor_descriptor(state)}.store({indexing.offsets_str()}, value)",
-            value=indexing.reshape_store(state, value),
+            f"{indexing.tensor_descriptor(state)}.store({indexing.offsets_str_permuted(state)}, value)",
+            value=store_value,
         )

@@ -371,9 +425,21 @@ def tensor_descriptor(self, state: CodegenState) -> str:
             self.base, self.block_shape
         ).name

+    def tensor_descriptor_arg(self, state: CodegenState) -> TensorDescriptorArg:
+        return state.device_function.tensor_descriptor_arg(self.base, self.block_shape)
+
     def offsets_str(self) -> str:
         return f"[{', '.join(self.offsets)}]"

+    def offsets_str_permuted(self, state: CodegenState) -> str:
+        """Get offsets string with permutation applied if needed."""
+        desc_arg = self.tensor_descriptor_arg(state)
+        if desc_arg.permutation is not None:
+            # Apply permutation to offsets
+            permuted_offsets = [self.offsets[i] for i in desc_arg.permutation]
+            return f"[{', '.join(permuted_offsets)}]"
+        return self.offsets_str()
+
     @property
     def ndim(self) -> int:
         return self.base.ndim
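Offsets and stored values are permuted into the descriptor's dimension order, and loads undo it with the inverse permutation, so the round trip is just a transpose in and out. A small torch illustration using the same hypothetical permutation as the earlier example:

    import torch

    # Stride==1 dim is dim 0, so permutation = [1, 2, 0] and its inverse is [2, 0, 1].
    permutation = [1, 2, 0]
    inverse_permutation = [2, 0, 1]

    value = torch.arange(24).reshape(2, 3, 4)      # block as the kernel body sees it

    # Store path: the value is permuted into the descriptor's dimension order.
    stored = value.permute(*permutation)           # shape (3, 4, 2)

    # Load path: the descriptor returns data in its own order; permuting by the
    # inverse restores the order the kernel body expects.
    loaded = stored.permute(*inverse_permutation)  # shape (2, 3, 4)

    assert torch.equal(loaded, value)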
@@ -427,7 +493,6 @@ def is_supported(
         index: list[object],
         extra_mask: ast.AST | None,
     ) -> bool:
-        # TODO(jansel): TensorDescriptor has some extra restrictions that are not captured here.
         if extra_mask is not None:
             # TODO(jansel): support block_ptr with extra_mask
             return False
