Skip to content

Commit 687ab9b

Browse files
committed
Implement persistent kernels
Enabled with `config["pid_type"]="persistent_blocked"` or `"persistent_interleaved"`. This also refactors much of the program id handling.
1 parent 0ec7b0f commit 687ab9b

17 files changed

+1583
-179
lines changed

README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,9 +233,10 @@ Specifies the type of indexing code to generate. The `"tensor_descriptor"`
233233
option uses Tensor Memory Accelerators (TMAs) but requires a Hopper or
234234
newer GPU and the latest development version of Triton.
235235

236-
* **use\_yz\_grid** (`bool`):
237-
Determines if the `y` and `z` dimensions of the launch grid are utilized,
238-
or if only the `x` dimension is used. This option is ignored if `l2_groupings[0] > 1`.
236+
* **pid\_type** (`"flat"`, `"xyz"`, `"persistent_blocked"`, or `"persistent_interleaved"`):
237+
Specifies the program ID mapping strategy. `"flat"` uses only the x-dimension,
238+
`"xyz"` utilizes multiple grid dimensions, and persistent strategies enable
239+
persistent kernels for improved SM utilization.
239240

240241
* **num\_warps** (`int`):
241242
Sets the number of warps the kernel will use.

helion/_compiler/device_function.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
from ..runtime.config import Config
4040
from .generate_ast import GenerateAST
4141
from .program_id import ProgramIDs
42-
from .program_id import SharedProgramID
4342

4443
_P = TypeVar("_P", bound="TensorPropertyArg")
4544

@@ -178,7 +177,7 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
178177
self._unique_counter: dict[str, itertools.count[int]] = defaultdict(
179178
itertools.count
180179
)
181-
self.pid: SharedProgramID | ProgramIDs | None = None
180+
self.pid: ProgramIDs | None = None
182181
self.namespace: _Namespace = _Namespace()
183182
self.namespace._used_names.update(reserved_names())
184183
self._variable_renames: dict[str, list[str]] = {}
@@ -203,7 +202,7 @@ def merge_variable_names(self, a: str, b: str) -> None:
203202
for n in name_group:
204203
self._variable_renames[n] = name_group
205204

206-
def set_pid(self, pid: SharedProgramID | ProgramIDs) -> None:
205+
def set_pid(self, pid: ProgramIDs) -> None:
207206
assert self.pid is None, "pid already set"
208207
self.pid = pid
209208

helion/_compiler/device_ir.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -893,8 +893,8 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
893893
remove_unnecessary_masking(graph.graph)
894894
device_ir.build_rolled_reductions()
895895
if len(device_ir.root_ids) > 1:
896-
# yz_grid not supported with shared program IDs
897-
CompileEnvironment.current().config_spec.allow_use_yz_grid = False
896+
# xyz not supported with shared program IDs, but persistent kernels are allowed
897+
CompileEnvironment.current().config_spec.disallow_pid_type("xyz")
898898
return device_ir
899899

900900

helion/_compiler/generate_ast.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from .device_function import DeviceFunction
2020
from .inductor_lowering import CodegenState
2121
from .inductor_lowering import codegen_call_with_graph
22-
from .program_id import SharedProgramID
22+
from .program_id import ForEachProgramID
2323
from .variable_origin import ArgumentOrigin
2424

2525
if TYPE_CHECKING:
@@ -156,11 +156,11 @@ def visit_For(self, node: ast.For) -> ast.AST | None:
156156

157157
if node._root_id == 0:
158158
self.device_function.set_pid(
159-
SharedProgramID(
159+
ForEachProgramID(
160160
self.device_function.new_var("pid_shared", dce=False)
161161
)
162162
)
163-
self.device_function.body.append(
163+
self.device_function.body.extend(
164164
self.device_function.pid.codegen_pid_init()
165165
)
166166
if node._root_id < len(self.host_function.device_ir.root_ids) - 1:
@@ -231,8 +231,14 @@ def visit_For(self, node: ast.For) -> ast.AST | None:
231231
orelse=self.next_else_block,
232232
)
233233
)
234-
self.device_function.dead_code_elimination()
235234
if node._root_id == len(self.host_function.device_ir.root_ids) - 1:
235+
if self.device_function.pid is not None:
236+
persistent_body = self.device_function.pid.setup_persistent_kernel(
237+
self.device_function
238+
)
239+
if persistent_body is not None:
240+
self.device_function.body = persistent_body
241+
self.device_function.dead_code_elimination()
236242
return self.device_function.codegen_function_call()
237243
return None
238244
return self.generic_visit(node)

helion/_compiler/output_header.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
[
2828
SOURCE_MODULE,
2929
"make_precompiler",
30+
"_NUM_SM",
3031
]
3132
)
3233

0 commit comments

Comments
 (0)