Skip to content

Implement persistent kernels #238

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,10 @@ Specifies the type of indexing code to generate. The `"tensor_descriptor"`
option uses Tensor Memory Accelerators (TMAs) but requires a Hopper or
newer GPU and the latest development version of Triton.

* **use\_yz\_grid** (`bool`):
Determines if the `y` and `z` dimensions of the launch grid are utilized,
or if only the `x` dimension is used. This option is ignored if `l2_groupings[0] > 1`.
* **pid\_type** (`"flat"`, `"xyz"`, `"persistent_blocked"`, or `"persistent_interleaved"`):
Specifies the program ID mapping strategy. `"flat"` uses only the x-dimension,
`"xyz"` utilizes multiple grid dimensions, and persistent strategies enable
persistent kernels for improved SM utilization.

* **num\_warps** (`int`):
Sets the number of warps the kernel will use.
Expand Down
5 changes: 2 additions & 3 deletions helion/_compiler/device_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
from ..runtime.config import Config
from .generate_ast import GenerateAST
from .program_id import ProgramIDs
from .program_id import SharedProgramID

_P = TypeVar("_P", bound="TensorPropertyArg")

Expand Down Expand Up @@ -178,7 +177,7 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
self._unique_counter: dict[str, itertools.count[int]] = defaultdict(
itertools.count
)
self.pid: SharedProgramID | ProgramIDs | None = None
self.pid: ProgramIDs | None = None
self.namespace: _Namespace = _Namespace()
self.namespace._used_names.update(reserved_names())
self._variable_renames: dict[str, list[str]] = {}
Expand All @@ -203,7 +202,7 @@ def merge_variable_names(self, a: str, b: str) -> None:
for n in name_group:
self._variable_renames[n] = name_group

def set_pid(self, pid: SharedProgramID | ProgramIDs) -> None:
def set_pid(self, pid: ProgramIDs) -> None:
assert self.pid is None, "pid already set"
self.pid = pid

Expand Down
4 changes: 2 additions & 2 deletions helion/_compiler/device_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,8 +893,8 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
remove_unnecessary_masking(graph.graph)
device_ir.build_rolled_reductions()
if len(device_ir.root_ids) > 1:
# yz_grid not supported with shared program IDs
CompileEnvironment.current().config_spec.allow_use_yz_grid = False
# xyz not supported with shared program IDs, but persistent kernels are allowed
CompileEnvironment.current().config_spec.disallow_pid_type("xyz")
return device_ir


Expand Down
14 changes: 10 additions & 4 deletions helion/_compiler/generate_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .device_function import DeviceFunction
from .inductor_lowering import CodegenState
from .inductor_lowering import codegen_call_with_graph
from .program_id import SharedProgramID
from .program_id import ForEachProgramID
from .variable_origin import ArgumentOrigin

if TYPE_CHECKING:
Expand Down Expand Up @@ -156,11 +156,11 @@ def visit_For(self, node: ast.For) -> ast.AST | None:

if node._root_id == 0:
self.device_function.set_pid(
SharedProgramID(
ForEachProgramID(
self.device_function.new_var("pid_shared", dce=False)
)
)
self.device_function.body.append(
self.device_function.body.extend(
self.device_function.pid.codegen_pid_init()
)
if node._root_id < len(self.host_function.device_ir.root_ids) - 1:
Expand Down Expand Up @@ -231,8 +231,14 @@ def visit_For(self, node: ast.For) -> ast.AST | None:
orelse=self.next_else_block,
)
)
self.device_function.dead_code_elimination()
if node._root_id == len(self.host_function.device_ir.root_ids) - 1:
if self.device_function.pid is not None:
persistent_body = self.device_function.pid.setup_persistent_kernel(
self.device_function
)
if persistent_body is not None:
self.device_function.body = persistent_body
self.device_function.dead_code_elimination()
return self.device_function.codegen_function_call()
return None
return self.generic_visit(node)
Expand Down
1 change: 1 addition & 0 deletions helion/_compiler/output_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
[
SOURCE_MODULE,
"make_precompiler",
"_NUM_SM",
]
)

Expand Down
Loading
Loading