From a6d50312c6a9a741fbb5fc194f707c0eee215692 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 3 Jul 2025 11:31:51 -0700 Subject: [PATCH] Implement persistent kernels Enabled with `config["pid_type"]="persistent_blocked"` or `"persistent_interleaved"`. This also refactors much of the program id handling. stack-info: PR: https://github.com/pytorch-labs/helion/pull/238, branch: jansel/stack/77 --- README.md | 7 +- helion/_compiler/device_function.py | 5 +- helion/_compiler/device_ir.py | 4 +- helion/_compiler/generate_ast.py | 14 +- helion/_compiler/output_header.py | 1 + helion/_compiler/program_id.py | 493 +++++++++++--- helion/_compiler/tile_strategy.py | 42 +- helion/autotuner/config_spec.py | 57 +- helion/language/loops.py | 5 +- helion/runtime/__init__.py | 14 + helion/runtime/config.py | 11 +- test/test_autotuner.py | 40 +- test/test_examples.py | 2 +- test/test_generate_ast.py | 2 +- test/test_persistent_kernels.py | 974 ++++++++++++++++++++++++++++ test/test_register_tunable.py | 2 +- 16 files changed, 1494 insertions(+), 179 deletions(-) create mode 100644 test/test_persistent_kernels.py diff --git a/README.md b/README.md index 83bc690f..fac68dcf 100644 --- a/README.md +++ b/README.md @@ -233,9 +233,10 @@ Specifies the type of indexing code to generate. The `"tensor_descriptor"` option uses Tensor Memory Accelerators (TMAs) but requires a Hopper or newer GPU and the latest development version of Triton. -* **use\_yz\_grid** (`bool`): - Determines if the `y` and `z` dimensions of the launch grid are utilized, - or if only the `x` dimension is used. This option is ignored if `l2_groupings[0] > 1`. +* **pid\_type** (`"flat"`, `"xyz"`, `"persistent_blocked"`, or `"persistent_interleaved"`): + Specifies the program ID mapping strategy. `"flat"` uses only the x-dimension, + `"xyz"` utilizes multiple grid dimensions, and persistent strategies enable + persistent kernels for improved SM utilization. * **num\_warps** (`int`): Sets the number of warps the kernel will use. 
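For concreteness, a minimal usage sketch of the new option (a hand-written example, not part of the patch): it mirrors the kernels defined in test/test_persistent_kernels.py below, and it assumes the Config-in-decorator form used by the README's existing examples.

    import torch
    import helion
    import helion.language as hl

    # pid_type is an ordinary Config knob; "persistent_blocked" is one of the
    # four values accepted after this patch.
    @helion.kernel(config=helion.Config(block_sizes=[32, 32], pid_type="persistent_blocked"))
    def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        out = x.new_empty(x.size())
        for tile in hl.tile(x.size()):
            out[tile] = x[tile] + y[tile]
        return out

The same kernel can instead be autotuned, in which case the autotuner now searches over all allowed pid_type values (see the config_spec.py changes below).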
diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py index 877fabbd..3af3fd90 100644 --- a/helion/_compiler/device_function.py +++ b/helion/_compiler/device_function.py @@ -39,7 +39,6 @@ from ..runtime.config import Config from .generate_ast import GenerateAST from .program_id import ProgramIDs - from .program_id import SharedProgramID _P = TypeVar("_P", bound="TensorPropertyArg") @@ -178,7 +177,7 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None: self._unique_counter: dict[str, itertools.count[int]] = defaultdict( itertools.count ) - self.pid: SharedProgramID | ProgramIDs | None = None + self.pid: ProgramIDs | None = None self.namespace: _Namespace = _Namespace() self.namespace._used_names.update(reserved_names()) self._variable_renames: dict[str, list[str]] = {} @@ -203,7 +202,7 @@ def merge_variable_names(self, a: str, b: str) -> None: for n in name_group: self._variable_renames[n] = name_group - def set_pid(self, pid: SharedProgramID | ProgramIDs) -> None: + def set_pid(self, pid: ProgramIDs) -> None: assert self.pid is None, "pid already set" self.pid = pid diff --git a/helion/_compiler/device_ir.py b/helion/_compiler/device_ir.py index eb86de87..59f38086 100644 --- a/helion/_compiler/device_ir.py +++ b/helion/_compiler/device_ir.py @@ -893,8 +893,8 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR: remove_unnecessary_masking(graph.graph) device_ir.build_rolled_reductions() if len(device_ir.root_ids) > 1: - # yz_grid not supported with shared program IDs - CompileEnvironment.current().config_spec.allow_use_yz_grid = False + # xyz not supported with shared program IDs, but persistent kernels are allowed + CompileEnvironment.current().config_spec.disallow_pid_type("xyz") return device_ir diff --git a/helion/_compiler/generate_ast.py b/helion/_compiler/generate_ast.py index 41e3f209..85e0338f 100644 --- a/helion/_compiler/generate_ast.py +++ b/helion/_compiler/generate_ast.py @@ -19,7 +19,7 @@ from .device_function import DeviceFunction from .inductor_lowering import CodegenState from .inductor_lowering import codegen_call_with_graph -from .program_id import SharedProgramID +from .program_id import ForEachProgramID from .variable_origin import ArgumentOrigin if TYPE_CHECKING: @@ -156,11 +156,11 @@ def visit_For(self, node: ast.For) -> ast.AST | None: if node._root_id == 0: self.device_function.set_pid( - SharedProgramID( + ForEachProgramID( self.device_function.new_var("pid_shared", dce=False) ) ) - self.device_function.body.append( + self.device_function.body.extend( self.device_function.pid.codegen_pid_init() ) if node._root_id < len(self.host_function.device_ir.root_ids) - 1: @@ -231,8 +231,14 @@ def visit_For(self, node: ast.For) -> ast.AST | None: orelse=self.next_else_block, ) ) - self.device_function.dead_code_elimination() if node._root_id == len(self.host_function.device_ir.root_ids) - 1: + if self.device_function.pid is not None: + persistent_body = self.device_function.pid.setup_persistent_kernel( + self.device_function + ) + if persistent_body is not None: + self.device_function.body = persistent_body + self.device_function.dead_code_elimination() return self.device_function.codegen_function_call() return None return self.generic_visit(node) diff --git a/helion/_compiler/output_header.py b/helion/_compiler/output_header.py index 78f6f0ca..bc0ec205 100644 --- a/helion/_compiler/output_header.py +++ b/helion/_compiler/output_header.py @@ -27,6 +27,7 @@ [ SOURCE_MODULE, "make_precompiler", + "_NUM_SM", ] ) 
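To make the generated control flow concrete, the following is a self-contained, hand-written Triton sketch of the persistent_blocked pattern that setup_persistent_kernel wraps around a kernel body. It is not the generator's literal output, but the virtual_pid, total_pids, num_blocks_0, and _NUM_SM names match the variables introduced in program_id.py below.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _persistent_add2d(x_ptr, y_ptr, out_ptr, m, n,
                          _BM: tl.constexpr, _BN: tl.constexpr, _NUM_SM: tl.constexpr):
        num_blocks_0 = tl.cdiv(m, _BM)
        total_pids = num_blocks_0 * tl.cdiv(n, _BN)
        # Blocked schedule: each program owns a contiguous range of virtual PIDs.
        block_size = tl.cdiv(total_pids, _NUM_SM)
        start_pid = tl.program_id(0) * block_size
        end_pid = tl.minimum(start_pid + block_size, total_pids)
        for virtual_pid in tl.range(start_pid, end_pid):
            # Decompose the flat virtual PID the same way
            # _decompose_pid_to_statements does: mod first, then floor-div.
            pid_0 = virtual_pid % num_blocks_0
            pid_1 = virtual_pid // num_blocks_0
            offs_m = pid_0 * _BM + tl.arange(0, _BM)
            offs_n = pid_1 * _BN + tl.arange(0, _BN)
            mask = (offs_m[:, None] < m) & (offs_n[None, :] < n)
            idx = offs_m[:, None] * n + offs_n[None, :]
            tl.store(out_ptr + idx,
                     tl.load(x_ptr + idx, mask=mask) + tl.load(y_ptr + idx, mask=mask),
                     mask=mask)

    x = torch.randn(128, 256, device="cuda")
    y = torch.randn_like(x)
    out = torch.empty_like(x)
    # Host side mirrors the new helion.runtime.get_num_sm(): the launch grid is
    # the SM count, not the problem size.
    num_sm = torch.cuda.get_device_properties(x.device.index).multi_processor_count
    _persistent_add2d[(num_sm,)](x, y, out, 128, 256, _BM=32, _BN=32, _NUM_SM=num_sm)
    torch.testing.assert_close(out, x + y)

The persistent_interleaved variant differs only in the loop header, using tl.range(tl.program_id(0), total_pids, _NUM_SM) so that programs stride through the virtual PIDs instead of owning contiguous blocks.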
diff --git a/helion/_compiler/program_id.py b/helion/_compiler/program_id.py index 6d85ca67..72474425 100644 --- a/helion/_compiler/program_id.py +++ b/helion/_compiler/program_id.py @@ -1,68 +1,189 @@ from __future__ import annotations +import abc +import ast import dataclasses from typing import TYPE_CHECKING from typing import NamedTuple -from helion._compiler.ast_extension import expr_from_string -from helion._compiler.ast_extension import statement_from_string -from helion._compiler.host_function import HostFunction +from .ast_extension import expr_from_string +from .ast_extension import statement_from_string +from .compile_environment import CompileEnvironment +from .device_function import DeviceFunction +from .host_function import HostFunction if TYPE_CHECKING: - import ast - import sympy - from helion._compiler.inductor_lowering import CodegenState + from .inductor_lowering import CodegenState + +NUM_SM_VAR = "_NUM_SM" -class ProgramID(NamedTuple): +class PIDInfo(NamedTuple): pid_var: str block_size_var: str numel: sympy.Expr - def host_cdiv(self) -> str: - numel_str = HostFunction.current().sympy_expr(self.numel) + def num_pids_expr(self, *, is_device: bool) -> str: + """Get the number of PIDs expression for device or host.""" + if is_device: + context = DeviceFunction.current() + cdiv_func = "tl.cdiv" + else: + context = HostFunction.current() + cdiv_func = "triton.cdiv" + numel_str = context.sympy_expr(self.numel) if self.block_size_var == "1": return numel_str - return f"triton.cdiv({numel_str}, {self.block_size_var})" + return f"{cdiv_func}({numel_str}, {self.block_size_var})" - def device_cdiv(self, state: CodegenState) -> str: - numel_str = state.sympy_expr(self.numel) - if self.block_size_var == "1": - return numel_str - return f"tl.cdiv({numel_str}, {self.block_size_var})" + +@dataclasses.dataclass +class ProgramIDs(abc.ABC): + """Base class for all program ID strategies with common functionality.""" + + shared_pid_var: str | None = None + pid_info: list[PIDInfo] = dataclasses.field(default_factory=list) + + def append(self, pid: PIDInfo) -> None: + self.pid_info.append(pid) + + @abc.abstractmethod + def codegen(self, state: CodegenState) -> None: + raise NotImplementedError + + @abc.abstractmethod + def codegen_grid(self) -> ast.AST: + """Generate grid launch expression for kernel execution.""" + raise NotImplementedError + + def total_pids_expr(self, *, is_device: bool) -> str: + """Get total PIDs expression for device or host.""" + return " * ".join( + f"({pid.num_pids_expr(is_device=is_device)})" for pid in self.pid_info + ) + + def setup_persistent_kernel( + self, device_function: DeviceFunction, total_pids_expr: str | None = None + ) -> list[ast.stmt] | None: + """Setup persistent kernel if supported. 
Returns None if not a persistent kernel.""" + return None + + def _setup_persistent_kernel_and_wrap_body( + self, + device_function: DeviceFunction, + virtual_pid_var: str, + range_expr: str, + total_pids_expr: str | None = None, + ) -> list[ast.stmt]: + """Complete persistent kernel setup: prepare body, wrap in loop, and return.""" + from .ast_extension import create + + # Prepare body for persistent loop + wrapped_body = list(device_function.body) + if isinstance(device_function.pid, ForEachProgramID): + shared_pid_var = device_function.pid.shared_pid_var + wrapped_body = [ + statement_from_string(f"{shared_pid_var} = {virtual_pid_var}"), + *wrapped_body, + ] + + # Create the persistent loop that wraps the entire body + persistent_loop = create( + ast.For, + target=create(ast.Name, id=virtual_pid_var, ctx=ast.Store()), + iter=expr_from_string(range_expr), + body=wrapped_body, + orelse=[], + type_comment=None, + ) + return [persistent_loop] + + @property + def virtual_program_id(self) -> str: + """Get the virtual program ID expression for this strategy.""" + return "tl.program_id(0)" + + def _is_persistent(self) -> bool: + """Check if this is a persistent strategy. Default False.""" + return False + + def _decompose_pid_to_statements( + self, pid_var: str, state: CodegenState + ) -> list[ast.stmt]: + """Generate statements to decompose a single PID variable into multiple PID components.""" + num_blocks = [ + state.device_function.new_var(f"num_blocks_{i}") + for i in range(len(self.pid_info[:-1])) + ] + statements = [ + statement_from_string(f"{num_block} = {pid.num_pids_expr(is_device=True)}") + for num_block, pid in zip(num_blocks, self.pid_info[:-1], strict=True) + ] + for i, pid in enumerate(self.pid_info): + expr = pid_var + if i > 0: + divisor = " * ".join(num_blocks[:i]) + expr = f"({expr}) // ({divisor})" + if i + 1 < len(self.pid_info): + expr = f"({expr}) % ({num_blocks[i]})" + statements.append(statement_from_string(f"{pid.pid_var} = {expr}")) + return statements @dataclasses.dataclass -class SharedProgramID: +class ForEachProgramID(ProgramIDs): """ - Use the same PID for all blocks - TODO(oulgen): Currently only supports 1 dimension + Represents multiple top-level for loops in the Helion kernel. These turn into `if` statements in the generated code.
""" shared_pid_var: str - pids: list[ProgramIDs] = dataclasses.field(default_factory=list) - - def codegen_pid_init( - self, - ) -> ast.stmt: - return statement_from_string(f"{self.shared_pid_var} = tl.program_id(0)") - - def codegen_test(self, state: CodegenState) -> ast.AST: + cases: list[ProgramIDs] = dataclasses.field(default_factory=list) + pid_info: list[PIDInfo] = dataclasses.field(default_factory=list, init=False) + + def codegen_pid_init(self) -> list[ast.stmt]: + # Check if persistent kernels are enabled in config - if so, skip regular initialization + # as it will be handled by the persistent loop wrapper + from .device_function import DeviceFunction + + current_device_fn = DeviceFunction.current() + pid_type = current_device_fn.config.get("pid_type", "flat") + if isinstance(pid_type, str) and pid_type.startswith("persistent"): + return [] + return [statement_from_string(f"{self.shared_pid_var} = tl.program_id(0)")] + + def _get_cdiv_blocks( + self, state: CodegenState, exclude_last: bool = False + ) -> list[str]: + """Get non-empty cdiv expressions from cases.""" + cases = self.cases[:-1] if exclude_last else self.cases blocks = [] - for pid in self.pids: - blocks.append(pid.combined_device_cdiv(state)) + for pid in cases: + cdiv = pid.total_pids_expr(is_device=True) + if cdiv: # Only add non-empty cdiv expressions + blocks.append(cdiv) + return blocks - assert len(blocks) > 0 + def codegen_test(self, state: CodegenState) -> ast.AST: + blocks = self._get_cdiv_blocks(state) return expr_from_string(f"{self.shared_pid_var} < ({'+ '.join(blocks)})") - def codegen(self, state: CodegenState) -> None: - # TODO(oulgen): We need CSE between codegen_test and codegen for shared device cdivs - blocks = [] - for pid in self.pids[:-1]: - blocks.append(pid.combined_device_cdiv(state)) + def setup_persistent_kernel( + self, device_function: DeviceFunction, total_pids_expr: str | None = None + ) -> list[ast.stmt] | None: + # Persistent type will be the same for every case, so we can use the first one + return self.cases[0].setup_persistent_kernel( + device_function, self.total_pids_expr(is_device=True) + ) + + def total_pids_expr(self, *, is_device: bool) -> str: + """Get total PIDs expression for ForEachProgramID (sum of all pids).""" + cdivs = [pid.total_pids_expr(is_device=is_device) for pid in self.cases] + return " + ".join(cdivs) + def codegen(self, state: CodegenState) -> None: + blocks = self._get_cdiv_blocks(state, exclude_last=True) if blocks: state.codegen.statements_stack[-1].insert( 0, @@ -72,120 +193,282 @@ def codegen(self, state: CodegenState) -> None: ) def codegen_grid(self) -> ast.AST: - return expr_from_string( - f"({'+ '.join(pid.combined_host_cdiv() for pid in self.pids)},)" - ) - - -@dataclasses.dataclass -class ProgramIDs: - pids: list[ProgramID] = dataclasses.field(default_factory=list) - shared_pid_var: str | None = None + # Check if any of the pids is a persistent strategy + if self.cases[0]._is_persistent(): + # Use SM count grid for persistent kernels + return self.cases[0].codegen_grid() - def append(self, pid: ProgramID) -> None: - self.pids.append(pid) + # When persistent kernels are not active, use the full grid size + host_cdivs = [pid.total_pids_expr(is_device=False) for pid in self.cases] + return expr_from_string(f"({'+ '.join(host_cdivs)},)") - def codegen(self, state: CodegenState) -> None: - raise NotImplementedError - - def codegen_grid(self) -> ast.AST: - raise NotImplementedError - - def combined_device_cdiv(self, state: CodegenState) -> str: - raise 
NotImplementedError - - def combined_host_cdiv(self) -> str: - raise NotImplementedError + def _prepare_persistent_body( + self, + body: list[ast.AST], + device_function: DeviceFunction, + virtual_pid_var: str, + ) -> list[ast.AST]: + """Prepare body for persistent loop - handle ForEachProgramID assignment.""" + # In persistent kernels, replace ForEachProgramID init with virtual_pid assignment + return [ + statement_from_string(f"{self.shared_pid_var} = {virtual_pid_var}"), + *body, + ] -class GridProgramIDs(ProgramIDs): +class XYZProgramIDs(ProgramIDs): """Use the cuda x/y/z launch grid for PIDs""" def codegen(self, state: CodegenState) -> None: - for i, pid in enumerate(self.pids): + for i, pid in enumerate(self.pid_info): state.codegen.statements_stack[-1].insert( i, statement_from_string(f"{pid.pid_var} = tl.program_id({i})") ) def codegen_grid(self) -> ast.AST: - assert len(self.pids) <= 3 - return expr_from_string(f"({', '.join(pid.host_cdiv() for pid in self.pids)},)") + assert len(self.pid_info) <= 3 + return expr_from_string( + f"({', '.join(pid.num_pids_expr(is_device=False) for pid in self.pid_info)},)" + ) -class VirtualProgramIDs(ProgramIDs): +class FlatProgramIDs(ProgramIDs): """Only use the x grid and compute other dimensions""" - def combined_device_cdiv(self, state: CodegenState) -> str: - return " * ".join(pid.device_cdiv(state) for pid in self.pids) - - def combined_host_cdiv(self) -> str: - return " * ".join(f"({pid.host_cdiv()})" for pid in self.pids) - def codegen(self, state: CodegenState) -> None: pid_var = self.shared_pid_var or "tl.program_id(0)" - - num_blocks = [ - state.device_function.new_var(f"num_blocks_{i}") - for i in range(len(self.pids[:-1])) - ] - statements = [ - statement_from_string(f"{num_block} = {pid.device_cdiv(state)}") - for num_block, pid in zip(num_blocks, self.pids[:-1], strict=True) - ] - for i, pid in enumerate(self.pids): - expr = pid_var - if i > 0: - divisor = " * ".join(num_blocks[:i]) - expr = f"({expr}) // ({divisor})" - if i + 1 < len(self.pids): - expr = f"({expr}) % ({num_blocks[i]})" - statements.append(statement_from_string(f"{pid.pid_var} = {expr}")) + statements = self._decompose_pid_to_statements(pid_var, state) state.codegen.statements_stack[-1][:] = [ *statements, *state.codegen.statements_stack[-1], ] def codegen_grid(self) -> ast.AST: - return expr_from_string(f"({self.combined_host_cdiv()},)") + return expr_from_string(f"({self.total_pids_expr(is_device=False)},)") @dataclasses.dataclass class L2GroupingProgramIDs(ProgramIDs): """Used grouped iteration order to promote L2 cache reuse in matmuls""" + pid_info: list[PIDInfo] = dataclasses.field(default_factory=list, init=False) + parent_strategy: ProgramIDs | None = dataclasses.field(default=None) group_size: int = 1 + def append(self, pid: PIDInfo) -> None: + """Delegate to parent strategy.""" + assert self.parent_strategy is not None + self.parent_strategy.append(pid) + def codegen(self, state: CodegenState) -> None: - assert len(self.pids) == 2 + # Generate L2 grouping logic + # Note: Persistent kernel setup is handled by ForEachProgramID if needed + assert self.parent_strategy is not None + parent_pids = self.parent_strategy.pid_info + assert len(parent_pids) == 2 new_var = state.device_function.new_var - pid = "tl.program_id(0)" + + # Use shared_pid_var if we're in a ForEachProgramID context, otherwise use virtual_program_id + if isinstance(state.device_function.pid, ForEachProgramID): + pid = state.device_function.pid.shared_pid_var + else: + pid = 
self.virtual_program_id + num_pid_m = new_var("num_pid_m") num_pid_n = new_var("num_pid_n") num_pid_in_group = new_var("num_pid_in_group") group_id = new_var("group_id") first_pid_m = new_var("first_pid_m") group_size_m = new_var("group_size_m") - state.codegen.statements_stack[-1][:] = [ - statement_from_string(f"{num_pid_m} = {self.pids[0].device_cdiv(state)}"), - statement_from_string(f"{num_pid_n} = {self.pids[1].device_cdiv(state)}"), - statement_from_string( - f"{num_pid_in_group} = {self.group_size} * {num_pid_n}" - ), - statement_from_string(f"{group_id} = {pid} // {num_pid_in_group}"), - statement_from_string(f"{first_pid_m} = {group_id} * {self.group_size}"), - statement_from_string( - f"{group_size_m} = min({num_pid_m} - {first_pid_m}, {self.group_size})" - ), - statement_from_string( - f"{self.pids[0].pid_var} = {first_pid_m} + (({pid} % {num_pid_in_group}) % {group_size_m})" - ), - statement_from_string( - f"{self.pids[1].pid_var} = ({pid} % {num_pid_in_group}) // {group_size_m}" + + assignments = [ + (num_pid_m, parent_pids[0].num_pids_expr(is_device=True)), + (num_pid_n, parent_pids[1].num_pids_expr(is_device=True)), + (num_pid_in_group, f"{self.group_size} * {num_pid_n}"), + (group_id, f"{pid} // {num_pid_in_group}"), + (first_pid_m, f"{group_id} * {self.group_size}"), + (group_size_m, f"min({num_pid_m} - {first_pid_m}, {self.group_size})"), + ( + parent_pids[0].pid_var, + f"{first_pid_m} + (({pid} % {num_pid_in_group}) % {group_size_m})", ), + (parent_pids[1].pid_var, f"({pid} % {num_pid_in_group}) // {group_size_m}"), + ] + statements = [ + statement_from_string(f"{var} = {expr}") for var, expr in assignments + ] + + state.codegen.statements_stack[-1][:] = [ + *statements, *state.codegen.statements_stack[-1], ] + @property + def virtual_program_id(self) -> str: + """Get the virtual program ID expression using parent strategy.""" + assert self.parent_strategy is not None + return self.parent_strategy.virtual_program_id + def codegen_grid(self) -> ast.AST: - return expr_from_string( - f"({' * '.join(pid.host_cdiv() for pid in self.pids)},)" + assert self.parent_strategy is not None + return self.parent_strategy.codegen_grid() + + def setup_persistent_kernel( + self, device_function: DeviceFunction, total_pids_expr: str | None = None + ) -> list[ast.stmt] | None: + """Delegate to parent strategy.""" + assert self.parent_strategy is not None + return self.parent_strategy.setup_persistent_kernel( + device_function, total_pids_expr ) + + def _is_persistent(self) -> bool: + """Forward to parent strategy.""" + assert self.parent_strategy is not None + return self.parent_strategy._is_persistent() + + def total_pids_expr(self, *, is_device: bool) -> str: + """Forward to parent strategy.""" + assert self.parent_strategy is not None + return self.parent_strategy.total_pids_expr(is_device=is_device) + + +class PersistentProgramIDs(ProgramIDs): + """Base class for persistent kernels that use num_sms grid size.""" + + def __init__(self, is_blocked: bool = False) -> None: + super().__init__() + self.is_blocked: bool = is_blocked + device_function = DeviceFunction.current() + self.virtual_pid_var: str = device_function.new_var("virtual_pid") + self.total_pids_var: str = device_function.new_var("total_pids") + # Generate variables and range expression based on strategy type + if self.is_blocked: + self.block_size_var: str = device_function.new_var("block_size") + self.start_pid_var: str = device_function.new_var("start_pid") + self.end_pid_var: str = device_function.new_var("end_pid") + 
self.range_expr: str = f"tl.range({self.start_pid_var}, {self.end_pid_var})" + else: + self.range_expr: str = ( + f"tl.range(tl.program_id(0), {self.total_pids_var}, {NUM_SM_VAR})" + ) + if device_function.constexpr_arg(NUM_SM_VAR): + device = CompileEnvironment.current().device + device_function.codegen.host_statements.append( + statement_from_string( + f"{NUM_SM_VAR} = helion.runtime.get_num_sm(torch.{device!r})" + ) + ) + + def codegen_grid(self) -> ast.AST: + # Use num_sms for persistent kernels + return expr_from_string(f"({NUM_SM_VAR},)") + + def setup_persistent_kernel( + self, device_function: DeviceFunction, total_pids_expr: str | None = None + ) -> list[ast.stmt] | None: + """Setup persistent kernel and return the wrapped body.""" + # Get total PIDs expression + if total_pids_expr is None: + total_pids_expr = self.total_pids_expr(is_device=True) + + # Generate setup statements + setup_statements = [ + statement_from_string(f"{self.total_pids_var} = {total_pids_expr}"), + ] + + # Add strategy-specific setup statements for blocked strategies + if self.is_blocked: + if self.block_size_var and self.start_pid_var and self.end_pid_var: + assignments = [ + ( + self.block_size_var, + f"tl.cdiv({self.total_pids_var}, {NUM_SM_VAR})", + ), + ( + self.start_pid_var, + f"tl.program_id(0) * {self.block_size_var}", + ), + ( + self.end_pid_var, + f"tl.minimum({self.start_pid_var} + {self.block_size_var}, {self.total_pids_var})", + ), + ] + setup_statements.extend( + [ + statement_from_string(f"{var} = {expr}") + for var, expr in assignments + ] + ) + + device_function.preamble.extend(setup_statements) + return self._setup_persistent_kernel_and_wrap_body( + device_function, self.virtual_pid_var, self.range_expr, total_pids_expr + ) + + def _is_persistent(self) -> bool: + """Check if this is a persistent strategy.""" + return True + + def _decompose_virtual_pid( + self, + state: CodegenState, + virtual_pid_var: str, + setup_statements: list[ast.stmt], + ) -> None: + """Decompose virtual PID into individual PID variables.""" + # Use shared_pid_var if available, otherwise virtual_pid_var + pid_var = self.shared_pid_var or virtual_pid_var + statements = self._decompose_pid_to_statements(pid_var, state) + setup_statements.extend(statements) + + def _generate_pid_statements(self, state: CodegenState) -> list[ast.stmt]: + """Generate PID decomposition statements based on setup state.""" + if not self.virtual_pid_var: + # Generate regular PID decomposition + return self._decompose_pid_to_statements( + self.shared_pid_var or "tl.program_id(0)", state + ) + + # Generate persistent PID decomposition + statements = [] + self._decompose_virtual_pid(state, self.virtual_pid_var, statements) + return statements + + def _prepend_statements( + self, state: CodegenState, statements: list[ast.stmt] + ) -> None: + """Prepend statements to current statement stack.""" + current_statements = state.codegen.statements_stack[-1] + current_statements[:] = [*statements, *current_statements] + + def codegen(self, state: CodegenState) -> None: + """Common codegen logic for persistent kernels.""" + is_shared_pid = isinstance(state.device_function.pid, ForEachProgramID) + + # Set up persistent loop if needed (non-ForEachProgramID case only) + if not is_shared_pid and not self.virtual_pid_var: + self.setup_persistent_kernel(state.device_function) + + # Generate and prepend PID decomposition statements + statements = self._generate_pid_statements(state) + self._prepend_statements(state, statements) + + @property + def 
virtual_program_id(self) -> str: + """Get the virtual program ID expression for persistent strategies.""" + return self.virtual_pid_var + + +class PersistentBlockedProgramIDs(PersistentProgramIDs): + """Persistent kernels where each SM processes a contiguous block of virtual PIDs.""" + + def __init__(self) -> None: + super().__init__(is_blocked=True) + + +class PersistentInterleavedProgramIDs(PersistentProgramIDs): + """Persistent kernels where each SM processes every num_sms-th virtual PID.""" + + def __init__(self) -> None: + super().__init__(is_blocked=False) diff --git a/helion/_compiler/tile_strategy.py b/helion/_compiler/tile_strategy.py index 4871fb51..bec1e9ff 100644 --- a/helion/_compiler/tile_strategy.py +++ b/helion/_compiler/tile_strategy.py @@ -21,12 +21,14 @@ from .compile_environment import _has_unbacked from .compile_environment import _to_sympy from .host_function import HostFunction -from .program_id import GridProgramIDs +from .program_id import FlatProgramIDs +from .program_id import ForEachProgramID from .program_id import L2GroupingProgramIDs -from .program_id import ProgramID +from .program_id import PersistentBlockedProgramIDs +from .program_id import PersistentInterleavedProgramIDs +from .program_id import PIDInfo from .program_id import ProgramIDs -from .program_id import SharedProgramID -from .program_id import VirtualProgramIDs +from .program_id import XYZProgramIDs if TYPE_CHECKING: from collections.abc import Sequence @@ -388,6 +390,12 @@ def codegen_grid(self) -> ast.AST: f"(triton.cdiv({HostFunction.current().sympy_expr(total_numel)}, {block_size_var}), 1, 1)" ) + def codegen(self, state: CodegenState) -> None: + pass # No-op implementation for TmpPid + + def total_pids_expr(self, *, is_device: bool) -> str: + return "1" # Simple implementation for TmpPid + state.device_function.set_pid(TmpPid()) block_id_to_info = self._create_block_id_info_dict(state) @@ -491,7 +499,7 @@ def codegen_grid(self, state: CodegenState) -> DeviceGridState: block_sizes = self.block_size assert len(block_sizes) == len(block_ids) pids = self.select_pid_strategy() - if isinstance(state.device_function.pid, SharedProgramID): + if isinstance(state.device_function.pid, ForEachProgramID): pids.shared_pid_var = state.device_function.pid.shared_pid_var assert state.ast_args is None @@ -542,12 +550,12 @@ def codegen_grid(self, state: CodegenState) -> DeviceGridState: ) if mask_statement is not None: state.add_statement(mask_statement) - pid = ProgramID(pid_var, block_size_var, numel) + pid = PIDInfo(pid_var, block_size_var, numel) pids.append(pid) pids.codegen(state) - if isinstance(state.device_function.pid, SharedProgramID): + if isinstance(state.device_function.pid, ForEachProgramID): shared_pid = state.device_function.pid - shared_pid.pids.append(pids) + shared_pid.cases.append(pids) shared_pid.codegen(state) else: state.device_function.set_pid(pids) @@ -556,9 +564,16 @@ def codegen_grid(self, state: CodegenState) -> DeviceGridState: return DeviceGridState(self, block_id_to_info=block_id_to_info) def select_pid_strategy(self) -> ProgramIDs: - if 1 < len(self.block_ids) <= 3 and self.fn.config.use_yz_grid: - return GridProgramIDs() - return VirtualProgramIDs() + pid_type = self.fn.config.pid_type + if pid_type == "xyz": + assert 1 < len(self.block_ids) <= 3 + return XYZProgramIDs() + if pid_type == "persistent_blocked": + return PersistentBlockedProgramIDs() + if pid_type == "persistent_interleaved": + return PersistentInterleavedProgramIDs() + assert pid_type == "flat" + return 
FlatProgramIDs() def _to_ast(self, x: object, to_dtype: str | None = None) -> ast.AST: if isinstance(x, ast.AST): @@ -690,7 +705,10 @@ def _setup_mask( def select_pid_strategy(self) -> ProgramIDs: if self.l2_grouping > 1: - return L2GroupingProgramIDs(group_size=self.l2_grouping) + return L2GroupingProgramIDs( + group_size=self.l2_grouping, + parent_strategy=super().select_pid_strategy(), + ) return super().select_pid_strategy() diff --git a/helion/autotuner/config_spec.py b/helion/autotuner/config_spec.py index e8056cb1..df7c9fba 100644 --- a/helion/autotuner/config_spec.py +++ b/helion/autotuner/config_spec.py @@ -27,6 +27,9 @@ from collections.abc import Callable from collections.abc import Sequence + from helion.runtime.config import IndexingLiteral + from helion.runtime.config import PidTypeLiteral + DEFAULT_NUM_WARPS = 4 DEFAULT_NUM_STAGES = 3 VALID_KEYS: frozenset[str] = frozenset( @@ -43,10 +46,11 @@ "range_flattens", "num_warps", "num_stages", - "use_yz_grid", + "pid_type", "indexing", ] ) +VALID_PID_TYPES = ("flat", "xyz", "persistent_blocked", "persistent_interleaved") @dataclasses.dataclass @@ -84,7 +88,17 @@ class ConfigSpec: user_defined_tunables: dict[str, ConfigSpecFragment] = dataclasses.field( default_factory=dict ) - allow_use_yz_grid: bool | None = None + allowed_pid_types: tuple[PidTypeLiteral, ...] = dataclasses.field( + default_factory=functools.partial(tuple, VALID_PID_TYPES) + ) + + @staticmethod + def _valid_indexing_types() -> tuple[IndexingLiteral, ...]: + return ( + ("pointer", "block_ptr", "tensor_descriptor") + if supports_tensor_descriptor() + else ("pointer", "block_ptr") + ) def _remove_duplicates(self) -> None: self.loop_orders._remove_duplicates() @@ -96,6 +110,14 @@ def _remove_duplicates(self) -> None: self.range_multi_buffers._remove_duplicates() self.range_flattens._remove_duplicates() + def disallow_pid_type(self, pid_type: PidTypeLiteral) -> None: + """Disallow a pid_type from being used in the config.""" + # pyre-fixme[8] + self.allowed_pid_types = tuple( + [x for x in self.allowed_pid_types if x != pid_type] + ) + assert self.allowed_pid_types + def normalize(self, config: helion.Config | dict[str, object]) -> None: """Normalize the config to match the block_sizes and validate the config.""" if isinstance(config, helion.Config): @@ -154,10 +176,17 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None: config.setdefault("num_stages", DEFAULT_NUM_STAGES) # TODO(jansel): include num_ctas and max_nreg - if self.allow_use_yz_grid: - config.setdefault("use_yz_grid", False) - - config.setdefault("indexing", "pointer") + for name, values in ( + ("pid_type", VALID_PID_TYPES), + ("indexing", self._valid_indexing_types()), + ): + if name in config: + if config[name] not in values: + raise InvalidConfig( + f"Invalid value for {name!r}: {config[name]!r} must be one of {[*values]!r}" + ) + else: + config[name] = values[0] # Allow tunable parameter keys in addition to VALID_KEYS allowed_keys = VALID_KEYS | {*self.user_defined_tunables.keys()} @@ -182,25 +211,13 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf "range_flattens": self.range_flattens._flat_config(self, fn), "num_warps": fn(NumWarpsFragment(1, 32, DEFAULT_NUM_WARPS)), "num_stages": fn(IntegerFragment(1, 8, DEFAULT_NUM_STAGES)), - "indexing": fn( - EnumFragment( - ("pointer", "block_ptr", "tensor_descriptor") - if supports_tensor_descriptor() - else ("pointer", "block_ptr") - ) - ), + "indexing": fn(EnumFragment(self._valid_indexing_types())), + 
"pid_type": fn(EnumFragment(self.allowed_pid_types)), } # Add tunable parameters for key, fragment in self.user_defined_tunables.items(): config[key] = fn(fragment) - if self.allow_use_yz_grid: - use_yz_grid = fn(BooleanFragment()) - # pyre-ignore[16] - if (not config["l2_groupings"] or config["l2_groupings"][0] == 1) and ( - not config["flatten_loops"] or not config["flatten_loops"][0] - ): - config["use_yz_grid"] = use_yz_grid for name in ( "loop_orders", "flatten_loops", diff --git a/helion/language/loops.py b/helion/language/loops.py index 9fa0a4a1..addd8226 100644 --- a/helion/language/loops.py +++ b/helion/language/loops.py @@ -249,7 +249,8 @@ def _add_config_choices( if len(block_ids) == 2: # TODO(jansel): support L2 grouping with 3+ dims (and maybe non-grids?) config_spec.l2_groupings.append(L2GroupingSpec(block_ids)) - config_spec.allow_use_yz_grid = _allow_use_yz_grid(config_spec, block_ids) + if not _allow_use_yz_grid(config_spec, block_ids): + config_spec.disallow_pid_type("xyz") else: params = inspect.signature(triton.language.range).parameters for block_id in block_ids: @@ -279,7 +280,7 @@ def _supports_warp_specialize() -> bool: def _allow_use_yz_grid(config_spec: ConfigSpec, block_ids: list[int]) -> bool: """Check if the yz grid is allowed based on the block sizes.""" - if not (1 < len(block_ids) <= 3 and config_spec.allow_use_yz_grid is None): + if not (1 < len(block_ids) <= 3): return False hint = 1 try: diff --git a/helion/runtime/__init__.py b/helion/runtime/__init__.py index ac258d7d..98ce513b 100644 --- a/helion/runtime/__init__.py +++ b/helion/runtime/__init__.py @@ -25,3 +25,17 @@ def set_triton_allocator() -> None: except ImportError: pass triton.set_allocator(_alloc_fn) + + +def get_num_sm(device: torch.device) -> int: + """ + Get the number of streaming multiprocessors (SMs) for the specified device. + + Args: + device: Device to query. + + Returns: + Grid size to use for a persistent kernel on the device. + """ + assert device.type == "cuda", "TODO: implement for other devices" + return torch.cuda.get_device_properties(device.index).multi_processor_count diff --git a/helion/runtime/config.py b/helion/runtime/config.py index ed31b2ad..41885420 100644 --- a/helion/runtime/config.py +++ b/helion/runtime/config.py @@ -11,6 +11,7 @@ from helion.autotuner.config_spec import DEFAULT_NUM_WARPS IndexingLiteral = Literal["pointer", "tensor_descriptor", "block_ptr"] +PidTypeLiteral = Literal["flat", "xyz", "persistent_blocked", "persistent_interleaved"] class Config(Mapping[str, object]): @@ -32,7 +33,7 @@ def __init__( range_flattens: list[bool | None] | None = None, num_warps: int | None = None, num_stages: int | None = None, - use_yz_grid: bool | None = None, + pid_type: PidTypeLiteral | None = None, indexing: IndexingLiteral | None = None, # For user-defined properties **kwargs: object, @@ -52,7 +53,7 @@ def __init__( range_flattens: Controls flatten parameter for tl.range calls. num_warps: Number of warps per block. num_stages: Number of stages for software pipelining. - use_yz_grid: Whether to use yz grid dimensions. + pid_type: Program ID type strategy ("flat", "xyz", "persistent_blocked", "persistent_interleaved"). indexing: Indexing strategy ("pointer", "tensor_descriptor", "block_ptr"). **kwargs: Additional user-defined configuration parameters. 
""" @@ -71,7 +72,7 @@ def __init__( "num_warps": num_warps, "num_stages": num_stages, "indexing": indexing, - "use_yz_grid": use_yz_grid, + "pid_type": pid_type, } for key, value in core_props.items(): if value is not None: @@ -150,8 +151,8 @@ def l2_groupings(self) -> list[int]: return cast("list[int]", self.config.get("l2_groupings", [])) @property - def use_yz_grid(self) -> bool: - return cast("bool", self.config.get("use_yz_grid", False)) + def pid_type(self) -> PidTypeLiteral: + return cast("PidTypeLiteral", self.config.get("pid_type", "flat")) @property def range_unroll_factors(self) -> list[int]: diff --git a/test/test_autotuner.py b/test/test_autotuner.py index ecc990aa..a1699978 100644 --- a/test/test_autotuner.py +++ b/test/test_autotuner.py @@ -46,16 +46,16 @@ def test_config_fragment0(self): self.assertExpectedInline( "\n".join(map(repr, configs)), """\ -helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], range_warp_specializes=[None], range_num_stages=[0], range_multi_buffers=[None], range_flattens=[None], num_warps=4, num_stages=3, indexing='pointer') -helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[8], range_unroll_factors=[1], range_warp_specializes=[False], range_num_stages=[1], range_multi_buffers=[True], range_flattens=[False], num_warps=1, num_stages=7, indexing='tensor_descriptor') -helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[2], range_unroll_factors=[1], range_warp_specializes=[True], range_num_stages=[4], range_multi_buffers=[True], range_flattens=[True], num_warps=2, num_stages=8, indexing='tensor_descriptor') -helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[1], range_unroll_factors=[0], range_warp_specializes=[True], range_num_stages=[1], range_multi_buffers=[False], range_flattens=[False], num_warps=32, num_stages=2, indexing='tensor_descriptor') -helion.Config(block_sizes=[16, 32, 64], loop_orders=[[1, 0]], l2_groupings=[64], range_unroll_factors=[2], range_warp_specializes=[True], range_num_stages=[3], range_multi_buffers=[True], range_flattens=[None], num_warps=4, num_stages=7, indexing='pointer') -helion.Config(block_sizes=[256, 128, 16], loop_orders=[[0, 1]], l2_groupings=[2], range_unroll_factors=[4], range_warp_specializes=[True], range_num_stages=[4], range_multi_buffers=[None], range_flattens=[False], num_warps=8, num_stages=4, indexing='tensor_descriptor') -helion.Config(block_sizes=[16, 32, 16], loop_orders=[[1, 0]], l2_groupings=[16], range_unroll_factors=[0], range_warp_specializes=[True], range_num_stages=[2], range_multi_buffers=[None], range_flattens=[False], num_warps=1, num_stages=8, indexing='tensor_descriptor') -helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[32], range_unroll_factors=[2], range_warp_specializes=[False], range_num_stages=[2], range_multi_buffers=[False], range_flattens=[False], num_warps=4, num_stages=4, indexing='tensor_descriptor') -helion.Config(block_sizes=[16, 16, 64], loop_orders=[[0, 1]], l2_groupings=[8], range_unroll_factors=[0], range_warp_specializes=[None], range_num_stages=[2], range_multi_buffers=[False], range_flattens=[True], num_warps=16, num_stages=4, indexing='block_ptr') -helion.Config(block_sizes=[32, 16, 16], loop_orders=[[0, 1]], l2_groupings=[16], range_unroll_factors=[2], range_warp_specializes=[False], range_num_stages=[0], range_multi_buffers=[True], range_flattens=[False], num_warps=4, num_stages=1, indexing='tensor_descriptor')""", 
+helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], range_warp_specializes=[None], range_num_stages=[0], range_multi_buffers=[None], range_flattens=[None], num_warps=4, num_stages=3, indexing='pointer', pid_type='flat') +helion.Config(block_sizes=[16, 64, 32], loop_orders=[[1, 0]], l2_groupings=[8], range_unroll_factors=[1], range_warp_specializes=[False], range_num_stages=[1], range_multi_buffers=[True], range_flattens=[False], num_warps=1, num_stages=7, indexing='tensor_descriptor', pid_type='flat') +helion.Config(block_sizes=[32, 32, 16], loop_orders=[[1, 0]], l2_groupings=[2], range_unroll_factors=[4], range_warp_specializes=[True], range_num_stages=[1], range_multi_buffers=[True], range_flattens=[False], num_warps=16, num_stages=6, indexing='block_ptr', pid_type='persistent_interleaved') +helion.Config(block_sizes=[32, 16, 16], loop_orders=[[0, 1]], l2_groupings=[4], range_unroll_factors=[3], range_warp_specializes=[True], range_num_stages=[0], range_multi_buffers=[True], range_flattens=[None], num_warps=1, num_stages=4, indexing='block_ptr', pid_type='persistent_interleaved') +helion.Config(block_sizes=[32, 16, 16], loop_orders=[[1, 0]], l2_groupings=[4], range_unroll_factors=[4], range_warp_specializes=[False], range_num_stages=[1], range_multi_buffers=[True], range_flattens=[True], num_warps=16, num_stages=2, indexing='pointer', pid_type='persistent_interleaved') +helion.Config(block_sizes=[16, 32, 64], loop_orders=[[0, 1]], l2_groupings=[8], range_unroll_factors=[1], range_warp_specializes=[True], range_num_stages=[1], range_multi_buffers=[True], range_flattens=[None], num_warps=4, num_stages=2, indexing='tensor_descriptor', pid_type='flat') +helion.Config(block_sizes=[16, 16, 32], loop_orders=[[0, 1]], l2_groupings=[16], range_unroll_factors=[4], range_warp_specializes=[None], range_num_stages=[2], range_multi_buffers=[None], range_flattens=[None], num_warps=16, num_stages=1, indexing='tensor_descriptor', pid_type='persistent_blocked') +helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[32], range_unroll_factors=[2], range_warp_specializes=[None], range_num_stages=[0], range_multi_buffers=[True], range_flattens=[None], num_warps=32, num_stages=7, indexing='block_ptr', pid_type='flat') +helion.Config(block_sizes=[16, 32, 64], loop_orders=[[0, 1]], l2_groupings=[4], range_unroll_factors=[3], range_warp_specializes=[False], range_num_stages=[1], range_multi_buffers=[None], range_flattens=[False], num_warps=16, num_stages=6, indexing='block_ptr', pid_type='flat') +helion.Config(block_sizes=[32, 32, 16], loop_orders=[[1, 0]], l2_groupings=[64], range_unroll_factors=[2], range_warp_specializes=[True], range_num_stages=[2], range_multi_buffers=[None], range_flattens=[True], num_warps=4, num_stages=3, indexing='tensor_descriptor', pid_type='persistent_blocked')""", ) @patch.object(_compat, "_supports_tensor_descriptor", lambda: True) @@ -69,16 +69,16 @@ def test_config_fragment1(self): self.assertExpectedInline( "\n".join(map(repr, configs)), """\ -helion.Config(block_sizes=[8, 16, 16], loop_orders=[[0, 1, 2]], flatten_loops=[False], num_warps=4, num_stages=3, indexing='pointer') -helion.Config(block_sizes=[1, 64, 64], loop_orders=[[1, 2, 0]], flatten_loops=[False], num_warps=4, num_stages=4, indexing='tensor_descriptor') -helion.Config(block_sizes=[1, 64, 512], loop_orders=[[0, 1, 2]], flatten_loops=[True], num_warps=4, num_stages=4, indexing='pointer') -helion.Config(block_sizes=[1, 64, 128], 
loop_orders=[[1, 2, 0]], flatten_loops=[True], num_warps=1, num_stages=2, indexing='pointer') -helion.Config(block_sizes=[2, 16, 64], loop_orders=[[2, 0, 1]], flatten_loops=[False], num_warps=2, num_stages=5, indexing='tensor_descriptor') -helion.Config(block_sizes=[8, 1, 16], loop_orders=[[2, 1, 0]], flatten_loops=[True], num_warps=16, num_stages=7, indexing='pointer') -helion.Config(block_sizes=[1, 8, 512], loop_orders=[[1, 0, 2]], flatten_loops=[True], num_warps=8, num_stages=4, indexing='tensor_descriptor') -helion.Config(block_sizes=[2, 2, 32], loop_orders=[[2, 0, 1]], flatten_loops=[True], num_warps=1, num_stages=8, indexing='pointer') -helion.Config(block_sizes=[2, 16, 2], loop_orders=[[0, 2, 1]], flatten_loops=[True], num_warps=4, num_stages=7, indexing='block_ptr') -helion.Config(block_sizes=[2, 16, 2], loop_orders=[[0, 2, 1]], flatten_loops=[True], num_warps=1, num_stages=3, indexing='block_ptr')""", +helion.Config(block_sizes=[8, 16, 16], loop_orders=[[0, 1, 2]], flatten_loops=[False], num_warps=4, num_stages=3, indexing='pointer', pid_type='flat') +helion.Config(block_sizes=[1, 8, 8], loop_orders=[[1, 2, 0]], flatten_loops=[False], num_warps=4, num_stages=4, indexing='tensor_descriptor', pid_type='persistent_blocked') +helion.Config(block_sizes=[8, 512, 8], loop_orders=[[2, 1, 0]], flatten_loops=[True], num_warps=4, num_stages=4, indexing='pointer', pid_type='persistent_interleaved') +helion.Config(block_sizes=[1, 8, 16], loop_orders=[[1, 2, 0]], flatten_loops=[True], num_warps=1, num_stages=2, indexing='pointer', pid_type='persistent_interleaved') +helion.Config(block_sizes=[4, 64, 2], loop_orders=[[2, 0, 1]], flatten_loops=[False], num_warps=4, num_stages=8, indexing='pointer', pid_type='persistent_blocked') +helion.Config(block_sizes=[8, 256, 4], loop_orders=[[1, 2, 0]], flatten_loops=[False], num_warps=2, num_stages=5, indexing='pointer', pid_type='persistent_interleaved') +helion.Config(block_sizes=[2, 64, 64], loop_orders=[[0, 2, 1]], flatten_loops=[True], num_warps=32, num_stages=2, indexing='tensor_descriptor', pid_type='flat') +helion.Config(block_sizes=[4, 2, 128], loop_orders=[[0, 1, 2]], flatten_loops=[True], num_warps=1, num_stages=1, indexing='tensor_descriptor', pid_type='persistent_interleaved') +helion.Config(block_sizes=[1, 32, 32], loop_orders=[[1, 2, 0]], flatten_loops=[True], num_warps=16, num_stages=4, indexing='tensor_descriptor', pid_type='persistent_blocked') +helion.Config(block_sizes=[1, 512, 4], loop_orders=[[2, 1, 0]], flatten_loops=[False], num_warps=2, num_stages=6, indexing='block_ptr', pid_type='persistent_interleaved')""", ) def test_save_load_config(self): diff --git a/test/test_examples.py b/test/test_examples.py index 550aa64d..35307b22 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -1031,7 +1031,7 @@ def test_embedding_block_ptr(self): torch.nn.functional.embedding(*args), block_sizes=[8, 64], indexing="block_ptr", - use_yz_grid=True, + pid_type="xyz", ), """\ from __future__ import annotations diff --git a/test/test_generate_ast.py b/test/test_generate_ast.py index a4bd1006..bf794256 100644 --- a/test/test_generate_ast.py +++ b/test/test_generate_ast.py @@ -199,7 +199,7 @@ def test_add3d_xy_grid(self): torch.randn([100, 500, 10], device=DEVICE), ) code, result = code_and_output( - basic_kernels.add, args, block_sizes=[16, 16, 16], use_yz_grid=True + basic_kernels.add, args, block_sizes=[16, 16, 16], pid_type="xyz" ) torch.testing.assert_close(result, args[0] + args[1]) self.assertExpectedInline( diff --git 
a/test/test_persistent_kernels.py b/test/test_persistent_kernels.py new file mode 100644 index 00000000..92d39d60 --- /dev/null +++ b/test/test_persistent_kernels.py @@ -0,0 +1,974 @@ +from __future__ import annotations + +import unittest + +from expecttest import TestCase +import torch + +import helion +from helion._compat import supports_tensor_descriptor +from helion._testing import DEVICE +from helion._testing import code_and_output +import helion.language as hl + + +# Global kernel definitions to avoid duplication +@helion.kernel(use_default_config=True) +def add_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + result = x.new_empty(x.size()) + for tile in hl.grid(x.size()): + result[tile] = x[tile] + y[tile] + return result + + +@helion.kernel(use_default_config=True) +def matmul_kernel(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + M, K = A.size() + K2, N = B.size() + assert K == K2 + result = A.new_empty([M, N]) + + for tile_m, tile_n in hl.tile([M, N]): + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + for tile_k in hl.tile(K): + acc += A[tile_m, tile_k] @ B[tile_k, tile_n] + result[tile_m, tile_n] = acc + return result + + +@helion.kernel(use_default_config=True) +def add_3d_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + result = x.new_empty(x.size()) + for tile in hl.grid(x.size()): + result[tile] = x[tile] + y[tile] + return result + + +@helion.kernel(use_default_config=True) +def add1_kernel(x: torch.Tensor) -> torch.Tensor: + result = x.new_empty(x.size()) + for tile in hl.tile(x.size(), block_size=[32, 16]): + result[tile] = x[tile] + 1 + return result + + +class TestPersistentKernels(TestCase): + """Test persistent kernel codegen with different PID strategies.""" + + def test_persistent_blocked_simple_add(self): + """Test persistent blocked kernel with simple addition.""" + + args = ( + torch.randn([128, 256], device=DEVICE), + torch.randn([128, 256], device=DEVICE), + ) + + # Test with persistent_blocked + code, result = code_and_output(add_kernel, args, pid_type="persistent_blocked") + + # Check correctness + expected = args[0] + args[1] + torch.testing.assert_close(result, expected) + + # Check that code contains persistent kernel infrastructure + self.assertIn("virtual_pid", code) + self.assertIn("total_pids", code) + + def test_persistent_interleaved_simple_add(self): + """Test persistent interleaved kernel with simple addition.""" + + args = ( + torch.randn([128, 256], device=DEVICE), + torch.randn([128, 256], device=DEVICE), + ) + + # Test with persistent_interleaved + code, result = code_and_output( + add_kernel, args, pid_type="persistent_interleaved" + ) + + # Check correctness + expected = args[0] + args[1] + torch.testing.assert_close(result, expected) + + # Check that code contains persistent kernel infrastructure + self.assertIn("virtual_pid", code) + self.assertIn("total_pids", code) + + def test_persistent_blocked_matmul(self): + """Test persistent blocked kernel with matrix multiplication.""" + + args = ( + torch.randn([64, 128], device=DEVICE), + torch.randn([128, 96], device=DEVICE), + ) + + # Test with persistent_blocked + code_persistent, result_persistent = code_and_output( + matmul_kernel, args, pid_type="persistent_blocked", block_sizes=[32, 32, 32] + ) + + # Test with flat for comparison + code_flat, result_flat = code_and_output( + matmul_kernel, args, pid_type="flat", block_sizes=[32, 32, 32] + ) + + # Persistent and flat should produce identical results + torch.testing.assert_close(result_persistent, 
result_flat, atol=0, rtol=0) + + # Check correctness against PyTorch + expected = torch.matmul(args[0], args[1]) + torch.testing.assert_close(result_persistent, expected, atol=1e-1, rtol=1e-2) + + # Check that code contains persistent loop structure + self.assertIn("for virtual_pid in tl.range", code_persistent) + self.assertIn("virtual_pid", code_persistent) + + def test_persistent_interleaved_matmul(self): + """Test persistent interleaved kernel with matrix multiplication.""" + + args = ( + torch.randn([64, 128], device=DEVICE), + torch.randn([128, 96], device=DEVICE), + ) + + # Test with persistent_interleaved + code_persistent, result_persistent = code_and_output( + matmul_kernel, + args, + block_sizes=[16, 16, 32], + pid_type="persistent_interleaved", + ) + + # Test with flat for comparison + code_flat, result_flat = code_and_output( + matmul_kernel, + args, + block_sizes=[16, 16, 32], + pid_type="flat", + ) + + # Persistent and flat should produce identical results + torch.testing.assert_close(result_persistent, result_flat, atol=0, rtol=0) + + # Check correctness against PyTorch + expected = torch.matmul(args[0], args[1]) + torch.testing.assert_close(result_persistent, expected, atol=1e-1, rtol=1e-2) + + # Check that code contains persistent loop structure + self.assertIn("for virtual_pid in tl.range", code_persistent) + self.assertIn("virtual_pid", code_persistent) + + def test_persistent_blocked_3d(self): + """Test persistent blocked kernel with 3D tensor.""" + + args = ( + torch.randn([32, 64, 48], device=DEVICE), + torch.randn([32, 64, 48], device=DEVICE), + ) + + # Test with persistent_blocked + code_persistent, result_persistent = code_and_output( + add_3d_kernel, args, pid_type="persistent_blocked" + ) + + # Test with flat for comparison + code_flat, result_flat = code_and_output(add_3d_kernel, args, pid_type="flat") + + # Persistent and flat should produce identical results + torch.testing.assert_close(result_persistent, result_flat, atol=0, rtol=0) + + # Check correctness against expected + expected = args[0] + args[1] + torch.testing.assert_close(result_persistent, expected) + + # Check that code contains persistent kernel infrastructure with 3D decomposition + self.assertIn("virtual_pid", code_persistent) + self.assertIn("num_blocks_0", code_persistent) + self.assertIn("num_blocks_1", code_persistent) + + def test_persistent_interleaved_3d(self): + """Test persistent interleaved kernel with 3D tensor.""" + + args = ( + torch.randn([32, 64, 48], device=DEVICE), + torch.randn([32, 64, 48], device=DEVICE), + ) + + # Test with persistent_interleaved + code_persistent, result_persistent = code_and_output( + add_3d_kernel, + args, + pid_type="persistent_interleaved", + ) + + # Test with flat for comparison + code_flat, result_flat = code_and_output( + add_3d_kernel, + args, + pid_type="flat", + ) + + # Persistent and flat should produce identical results + torch.testing.assert_close(result_persistent, result_flat, atol=0, rtol=0) + + # Check correctness against expected + expected = args[0] + args[1] + torch.testing.assert_close(result_persistent, expected) + + # Check that code contains persistent kernel infrastructure with 3D decomposition + self.assertIn("virtual_pid", code_persistent) + self.assertIn("num_blocks_0", code_persistent) + self.assertIn("num_blocks_1", code_persistent) + + def test_flat_vs_persistent_blocked_equivalence(self): + """Test that flat and persistent_blocked produce same results.""" + + args = ( + torch.randn([64, 128], device=DEVICE), + 
torch.randn([64, 128], device=DEVICE), + ) + + # Test with flat + _, result_flat = code_and_output(add_kernel, args, pid_type="flat") + + # Test with persistent_blocked + _, result_persistent = code_and_output( + add_kernel, args, pid_type="persistent_blocked" + ) + + # Should produce identical results + torch.testing.assert_close(result_flat, result_persistent) + + def test_xyz_vs_persistent_interleaved_equivalence(self): + """Test that xyz and persistent_interleaved produce same results.""" + + args = ( + torch.randn([64, 128], device=DEVICE), + torch.randn([64, 128], device=DEVICE), + ) + + # Test with xyz + _, result_xyz = code_and_output(add_kernel, args, pid_type="xyz") + + # Test with persistent_interleaved + _, result_persistent = code_and_output( + add_kernel, args, pid_type="persistent_interleaved" + ) + + # Should produce identical results + torch.testing.assert_close(result_xyz, result_persistent) + + def test_persistent_kernels_with_shared_program_id(self): + """Test persistent kernels with multiple top-level for loops to trigger ForEachProgramID. + + Note: In the current implementation, ForEachProgramID generates if statements at the top level, + and persistent kernels work within each if branch. This is different from the ideal + architecture where persistent kernels would generate while loops containing ForEachProgramID + if statements, but it still provides the hierarchical functionality. + """ + + @helion.kernel(use_default_config=True) + def multi_loop_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + result1 = x.new_empty(x.size()) + result2 = y.new_empty(y.size()) + + # First top-level loop - will get its own PID + for tile1 in hl.grid(x.size()): + result1[tile1] = x[tile1] * 2 + + # Second top-level loop - will trigger ForEachProgramID + for tile2 in hl.grid(y.size()): + result2[tile2] = y[tile2] * 3 + + return result1, result2 + + torch.manual_seed(42) # Set seed for reproducible results + args = ( + torch.randn([8, 12], device=DEVICE), + torch.randn([8, 12], device=DEVICE), + ) + + # Test with persistent_blocked + code_blocked, results_blocked = code_and_output( + multi_loop_kernel, args, pid_type="persistent_blocked" + ) + + # Test with persistent_interleaved + code_interleaved, results_interleaved = code_and_output( + multi_loop_kernel, args, pid_type="persistent_interleaved" + ) + + # Test with flat for comparison + code_flat, results_flat = code_and_output( + multi_loop_kernel, args, pid_type="flat" + ) + + # First verify all strategies produce identical results (most important check) + torch.testing.assert_close(results_blocked[0], results_flat[0], atol=0, rtol=0) + torch.testing.assert_close(results_blocked[1], results_flat[1], atol=0, rtol=0) + torch.testing.assert_close( + results_interleaved[0], results_flat[0], atol=0, rtol=0 + ) + torch.testing.assert_close( + results_interleaved[1], results_flat[1], atol=0, rtol=0 + ) + + # Calculate expected results + expected1 = args[0] * 2 + expected2 = args[1] * 3 + + # Check correctness against expected (using flat as reference since all should be identical) + torch.testing.assert_close(results_flat[0], expected1) + torch.testing.assert_close(results_flat[1], expected2) + + # Check that generated code contains ForEachProgramID patterns (not virtual_pid since ForEachProgramID disables persistent loops) + self.assertIn("pid_shared", code_blocked) + self.assertIn("if pid_shared <", code_blocked) + self.assertIn("pid_shared", code_interleaved) + self.assertIn("if pid_shared <", code_interleaved) + + def 
test_persistent_shared_vs_flat_shared_equivalence(self): + """Test that persistent+ForEachProgramID produces same results as flat+ForEachProgramID.""" + + @helion.kernel(use_default_config=True) + def shared_loops_kernel(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + output1 = a.new_empty(a.size()) + output2 = b.new_empty(b.size()) + + # Two top-level loops that will use ForEachProgramID + for tile_a in hl.grid(a.size()): + output1[tile_a] = a[tile_a] + 1.0 + + for tile_b in hl.grid(b.size()): + output2[tile_b] = b[tile_b] * 2.0 + + return output1, output2 + + torch.manual_seed(42) # Set seed for reproducible results + args = ( + torch.randn([8, 12], device=DEVICE), + torch.randn([8, 12], device=DEVICE), + ) + + # Test all strategies with ForEachProgramID + _, results_flat = code_and_output(shared_loops_kernel, args, pid_type="flat") + + _, results_persistent_blocked = code_and_output( + shared_loops_kernel, args, pid_type="persistent_blocked" + ) + + _, results_persistent_interleaved = code_and_output( + shared_loops_kernel, args, pid_type="persistent_interleaved" + ) + + # All strategies should produce identical results + torch.testing.assert_close(results_flat[0], results_persistent_blocked[0]) + torch.testing.assert_close(results_flat[1], results_persistent_blocked[1]) + torch.testing.assert_close(results_flat[0], results_persistent_interleaved[0]) + torch.testing.assert_close(results_flat[1], results_persistent_interleaved[1]) + torch.testing.assert_close( + results_persistent_blocked[0], results_persistent_interleaved[0] + ) + torch.testing.assert_close( + results_persistent_blocked[1], results_persistent_interleaved[1] + ) + + # Verify expected computation + expected1 = args[0] + 1.0 + expected2 = args[1] * 2.0 + torch.testing.assert_close(results_flat[0], expected1) + torch.testing.assert_close(results_flat[1], expected2) + + def test_persistent_kernels_complex_shared_scenario(self): + """Test persistent kernels with a more complex ForEachProgramID scenario.""" + + @helion.kernel(use_default_config=True) + def complex_shared_kernel( + x: torch.Tensor, y: torch.Tensor, z: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + result1 = x.new_empty(x.size()) + result2 = y.new_empty(y.size()) + + # First loop: process first input + for tile1 in hl.grid(x.size()): + result1[tile1] = x[tile1] + y[tile1] + + # Second loop: process second input (independent from first) + for tile2 in hl.grid(y.size()): + result2[tile2] = y[tile2] * z[tile2] + + return result1, result2 + + torch.manual_seed(42) # Set seed for reproducible results + args = ( + torch.randn([6, 8], device=DEVICE), + torch.randn([6, 8], device=DEVICE), + torch.randn([6, 8], device=DEVICE), + ) + + # Test persistent strategies + code_blocked, result_blocked = code_and_output( + complex_shared_kernel, args, pid_type="persistent_blocked" + ) + + code_interleaved, result_interleaved = code_and_output( + complex_shared_kernel, args, pid_type="persistent_interleaved" + ) + + # Test with flat for comparison + code_flat, result_flat = code_and_output( + complex_shared_kernel, args, pid_type="flat" + ) + + # All strategies should produce identical results + torch.testing.assert_close(result_blocked, result_flat, atol=0, rtol=0) + torch.testing.assert_close(result_interleaved, result_flat, atol=0, rtol=0) + + # Calculate expected results manually + expected1 = args[0] + args[1] + expected2 = args[1] * args[2] + + # Check correctness against PyTorch + torch.testing.assert_close(result_blocked[0], expected1, atol=1e-6, rtol=1e-6) 
+
+        # Verify the ForEachProgramID dispatch structure is present for both
+        # persistent strategies
+        self.assertIn("pid_shared", code_blocked)
+        self.assertIn("if pid_shared <", code_blocked)
+        self.assertIn("pid_shared", code_interleaved)
+        self.assertIn("if pid_shared <", code_interleaved)
+
+    def test_persistent_blocked_with_l2_grouping(self):
+        """Test persistent blocked kernels work with L2 grouping."""
+
+        args = (
+            torch.randn([64, 128], device=DEVICE),
+            torch.randn([64, 128], device=DEVICE),
+        )
+
+        # Test with persistent_blocked + l2_grouping=8
+        code_persistent_l2, result_persistent_l2 = code_and_output(
+            add_kernel, args, pid_type="persistent_blocked", l2_grouping=8
+        )
+
+        # Test with flat + l2_grouping=8 for comparison
+        code_flat_l2, result_flat_l2 = code_and_output(
+            add_kernel, args, pid_type="flat", l2_grouping=8
+        )
+
+        # Test with persistent_blocked alone for comparison
+        code_persistent, result_persistent = code_and_output(
+            add_kernel, args, pid_type="persistent_blocked", l2_grouping=1
+        )
+
+        # All should produce identical results
+        torch.testing.assert_close(result_persistent_l2, result_flat_l2, atol=0, rtol=0)
+        torch.testing.assert_close(
+            result_persistent_l2, result_persistent, atol=0, rtol=0
+        )
+
+        # Check correctness against expected
+        expected = args[0] + args[1]
+        torch.testing.assert_close(result_persistent_l2, expected)
+
+        # Check that persistent + L2 grouping code contains both features
+        self.assertIn("for virtual_pid in tl.range", code_persistent_l2)
+        self.assertIn("num_pid_in_group", code_persistent_l2)
+        self.assertIn("group_id", code_persistent_l2)
+        # Check that NUM_SM is used in device code and get_num_sm() in host code
+        self.assertIn("_NUM_SM: tl.constexpr", code_persistent_l2)
+        self.assertIn("helion.runtime.get_num_sm(", code_persistent_l2)
+
+    def test_shared_program_id_with_persistent_basic_functionality(self):
+        """Test that ForEachProgramID + persistent kernels generate correct code structure."""
+
+        @helion.kernel(use_default_config=True)
+        def multi_add_kernel(
+            x: torch.Tensor, y: torch.Tensor
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            result1 = x.new_empty(x.size())
+            result2 = y.new_empty(y.size())
+
+            # Two top-level loops to trigger ForEachProgramID
+            for tile1 in hl.grid(x.size()):
+                result1[tile1] = x[tile1] + 1.0
+
+            for tile2 in hl.grid(y.size()):
+                result2[tile2] = y[tile2] * 2.0
+
+            return result1, result2
+
+        torch.manual_seed(42)  # Set seed for reproducible results
+        args = (
+            torch.randn([8, 8], device=DEVICE),
+            torch.randn([8, 8], device=DEVICE),
+        )
+
+        # Test persistent + ForEachProgramID
+        code_persistent_shared, result_persistent_shared = code_and_output(
+            multi_add_kernel, args, pid_type="persistent_blocked"
+        )
+
+        # Check correctness - both results should be correct
+        expected1 = args[0] + 1.0
+        expected2 = args[1] * 2.0
+
+        # Note: when persistent kernels are used with ForEachProgramID (multiple
+        # top-level loops), the persistent loop drives the shared pid dispatch,
+        # so both results should still be computed correctly.
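+
+        # Sketch of the combined structure the assertions below check for (the
+        # quoted strings in the assertions are exact; the surrounding code here
+        # is illustrative):
+        #
+        #     for virtual_pid in tl.range(start_pid, end_pid):
+        #         pid_shared = virtual_pid
+        #         if pid_shared < num_pids_0:
+        #             ...first loop...
+        #         else:
+        #             ...second loop...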
+
+        torch.testing.assert_close(result_persistent_shared[0], expected1)
+        torch.testing.assert_close(result_persistent_shared[1], expected2)
+
+        # Check that code contains persistent loop with ForEachProgramID structure
+        # The new implementation correctly combines persistent kernels with ForEachProgramID
+        self.assertIn(
+            "for virtual_pid in tl.range(start_pid, end_pid)", code_persistent_shared
+        )
+        self.assertIn("pid_shared = virtual_pid", code_persistent_shared)
+        self.assertIn("if pid_shared <", code_persistent_shared)
+        # Should have the combined total calculation
+        self.assertIn(
+            "total_pids = x_size_0 * x_size_1 + y_size_0 * y_size_1",
+            code_persistent_shared,
+        )
+        # Grid should use SM count for persistent kernels
+        # Check that NUM_SM is used in device code and get_num_sm() in host code
+        self.assertIn("_NUM_SM: tl.constexpr", code_persistent_shared)
+        self.assertIn("helion.runtime.get_num_sm(", code_persistent_shared)
+
+    def test_simple_persistent_kernels_work(self):
+        """Test that simple persistent kernels compile and run correctly."""
+
+        @helion.kernel(use_default_config=True)
+        def simple_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            result = x.new_empty(x.size())
+            for tile in hl.tile(x.size(), block_size=[32, 16]):
+                result[tile] = x[tile] + y[tile]
+            return result
+
+        args = (
+            torch.randn([8, 12], device=DEVICE),
+            torch.randn([8, 12], device=DEVICE),
+        )
+        expected = args[0] + args[1]
+
+        # Test persistent_blocked
+        code_blocked, result_blocked = code_and_output(
+            simple_add, args, pid_type="persistent_blocked"
+        )
+        torch.testing.assert_close(result_blocked, expected)
+
+        # Verify correct grid size and loop structure
+        # Check that NUM_SM is used in device code and get_num_sm() in host code
+        self.assertIn("_NUM_SM: tl.constexpr", code_blocked)
+        self.assertIn("helion.runtime.get_num_sm(", code_blocked)
+        self.assertIn("for virtual_pid in tl.range", code_blocked)
+
+        # Test persistent_interleaved
+        code_interleaved, result_interleaved = code_and_output(
+            simple_add, args, pid_type="persistent_interleaved"
+        )
+        torch.testing.assert_close(result_interleaved, expected)
+
+        # Verify correct grid size and loop structure
+        # Check that NUM_SM is used in device code and get_num_sm() in host code
+        self.assertIn("_NUM_SM: tl.constexpr", code_interleaved)
+        self.assertIn("helion.runtime.get_num_sm(", code_interleaved)
+        self.assertIn("for virtual_pid in tl.range", code_interleaved)
+
+    def test_multi_loop_persistent_with_shared_program_id(self):
+        """Test that multi-loop persistent kernels with ForEachProgramID work correctly.
+
+        This is a regression test for the bug where multi-loop kernels with persistent
+        strategies would generate incorrect code with variable scoping issues.
+        """
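+
+        # Both hl.tile loops below compile into a single persistent kernel;
+        # roughly (illustrative, cf. the exact total_pids assertion in
+        # test_shared_program_id_with_persistent_basic_functionality), total_pids
+        # is the sum of the two loops' tile counts, and pid_shared selects which
+        # loop body a given virtual_pid executes.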
+ """ + + @helion.kernel(use_default_config=True) + def multi_loop_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + result1 = x.new_empty(x.size()) + result2 = y.new_empty(y.size()) + + # First loop - will get its own PID + for tile1 in hl.tile(x.size(), block_size=[16, 8]): + result1[tile1] = x[tile1] * 2 + + # Second loop - will trigger ForEachProgramID + for tile2 in hl.tile(y.size(), block_size=[16, 8]): + result2[tile2] = y[tile2] * 3 + + return result1, result2 + + args = (torch.randn([4, 6], device=DEVICE), torch.randn([4, 6], device=DEVICE)) + expected1 = args[0] * 2 + expected2 = args[1] * 3 + + # Test with persistent_blocked - this was failing before the fix + code_blocked, results_blocked = code_and_output( + multi_loop_kernel, args, pid_type="persistent_blocked" + ) + torch.testing.assert_close(results_blocked[0], expected1) + torch.testing.assert_close(results_blocked[1], expected2) + + # Verify ForEachProgramID structure is present + self.assertIn("pid_shared", code_blocked) + self.assertIn("if pid_shared <", code_blocked) + + # Test with persistent_interleaved + code_interleaved, results_interleaved = code_and_output( + multi_loop_kernel, args, pid_type="persistent_interleaved" + ) + torch.testing.assert_close(results_interleaved[0], expected1) + torch.testing.assert_close(results_interleaved[1], expected2) + + # Verify ForEachProgramID structure is present + self.assertIn("pid_shared", code_interleaved) + self.assertIn("if pid_shared <", code_interleaved) + + def test_persistent_grid_size_correctness(self): + """Test that persistent kernels use NUM_SMS grid size, not full grid size.""" + + @helion.kernel(use_default_config=True) + def test_kernel(x: torch.Tensor) -> torch.Tensor: + result = x.new_empty(x.size()) + for tile in hl.tile(x.size(), block_size=[32, 16]): + result[tile] = x[tile] + 1 + return result + + args = (torch.randn([64, 96], device=DEVICE),) + + # Get codes for different strategies + code_flat, _ = code_and_output(test_kernel, args, pid_type="flat") + code_persistent_blocked, _ = code_and_output( + test_kernel, args, pid_type="persistent_blocked" + ) + code_persistent_interleaved, _ = code_and_output( + test_kernel, args, pid_type="persistent_interleaved" + ) + + # Extract grid sizes from kernel calls - look for the pattern _kernel[grid_size,] + import re + + # Use a more flexible pattern that captures everything between _kernel[ and ,] + flat_grid_match = re.search(r"_kernel\[([^\]]+),\]", code_flat) + persistent_blocked_grid_match = re.search( + r"_kernel\[([^\]]+),\]", code_persistent_blocked + ) + persistent_interleaved_grid_match = re.search( + r"_kernel\[([^\]]+),\]", code_persistent_interleaved + ) + + self.assertIsNotNone(flat_grid_match, "Could not find grid size in flat code") + self.assertIsNotNone( + persistent_blocked_grid_match, + "Could not find grid size in persistent blocked code", + ) + self.assertIsNotNone( + persistent_interleaved_grid_match, + "Could not find grid size in persistent interleaved code", + ) + + flat_grid = flat_grid_match.group(1).rstrip(",") # Remove trailing comma + persistent_blocked_grid = persistent_blocked_grid_match.group(1).rstrip(",") + persistent_interleaved_grid = persistent_interleaved_grid_match.group(1).rstrip( + "," + ) + + # Flat should use the full grid size calculation + self.assertIn("triton.cdiv", flat_grid) + + # Persistent kernels should use NUM_SMS + self.assertEqual( + persistent_blocked_grid, + "_NUM_SM", + ) + self.assertEqual( + persistent_interleaved_grid, + "_NUM_SM", + ) + + def 
+    def test_persistent_loop_variable_names(self):
+        """Test that persistent kernels use correct virtual_pid variable names."""
+
+        @helion.kernel(use_default_config=True)
+        def test_kernel(x: torch.Tensor) -> torch.Tensor:
+            result = x.new_empty(x.size())
+            for tile in hl.tile(x.size(), block_size=[32, 16]):
+                result[tile] = x[tile] + 1
+            return result
+
+        args = (torch.randn([32, 48], device=DEVICE),)
+
+        # Test blocked strategy
+        code_blocked, _ = code_and_output(
+            test_kernel, args, pid_type="persistent_blocked"
+        )
+
+        # Should have the correct loop structure
+        self.assertIn("for virtual_pid in tl.range(start_pid, end_pid):", code_blocked)
+        self.assertIn("pid_0 = virtual_pid %", code_blocked)
+        self.assertIn("pid_1 = virtual_pid //", code_blocked)
+
+        # Test interleaved strategy
+        code_interleaved, _ = code_and_output(
+            test_kernel, args, pid_type="persistent_interleaved"
+        )
+
+        # Should have the correct loop structure
+        self.assertIn(
+            "for virtual_pid in tl.range(tl.program_id(0), total_pids, _NUM_SM):",
+            code_interleaved,
+        )
+        self.assertIn("pid_0 = virtual_pid %", code_interleaved)
+        self.assertIn("pid_1 = virtual_pid //", code_interleaved)
+
+    def test_persistent_1d_tiling(self):
+        """Test persistent kernels with 1D tiling."""
+
+        @helion.kernel(use_default_config=True)
+        def vector_add_1d(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            result = x.new_empty(x.size())
+            for tile in hl.tile(x.size(), block_size=[128]):
+                result[tile] = x[tile] + y[tile]
+            return result
+
+        args = (
+            torch.randn([1024], device=DEVICE),
+            torch.randn([1024], device=DEVICE),
+        )
+        expected = args[0] + args[1]
+
+        # Test persistent_blocked with 1D
+        code_blocked, result_blocked = code_and_output(
+            vector_add_1d, args, pid_type="persistent_blocked"
+        )
+        torch.testing.assert_close(result_blocked, expected)
+
+        # Verify 1D persistent loop structure
+        self.assertIn("for virtual_pid in tl.range", code_blocked)
+        self.assertIn("pid_0 = virtual_pid", code_blocked)
+        self.assertNotIn("pid_1", code_blocked)  # Should not have pid_1 for 1D
+
+        # Test persistent_interleaved with 1D
+        code_interleaved, result_interleaved = code_and_output(
+            vector_add_1d, args, pid_type="persistent_interleaved"
+        )
+        torch.testing.assert_close(result_interleaved, expected)
+
+        # Verify 1D persistent loop structure
+        self.assertIn("for virtual_pid in tl.range", code_interleaved)
+        self.assertIn("pid_0 = virtual_pid", code_interleaved)
+        self.assertNotIn("pid_1", code_interleaved)  # Should not have pid_1 for 1D
+
+        # Test correctness vs flat
+        code_flat, result_flat = code_and_output(vector_add_1d, args, pid_type="flat")
+        torch.testing.assert_close(result_blocked, result_flat, atol=0, rtol=0)
+        torch.testing.assert_close(result_interleaved, result_flat, atol=0, rtol=0)
+
+    def test_persistent_interleaved_with_l2_grouping_single_loop(self):
+        """Test persistent_interleaved with l2_grouping (2D iteration space) - single loop case."""
+
+        @helion.kernel(use_default_config=True)
+        def single_loop_l2_kernel(x: torch.Tensor) -> torch.Tensor:
+            result = x.new_empty(x.size())
+            # Single top-level hl.tile loop with 2D iteration space
+            for tile in hl.tile(x.size(), block_size=[16, 16]):
+                result[tile] = x[tile] * 2.0
+            return result
+
+        args = (torch.randn([64, 128], device=DEVICE),)
+
+        # Test with persistent_interleaved + l2_grouping=4 (requires 2D iteration space)
+        code, result = code_and_output(
+            single_loop_l2_kernel,
+            args,
+            pid_type="persistent_interleaved",
+            l2_grouping=4,
+        )
+
+        # Check correctness
+        expected = args[0] * 2.0
+        torch.testing.assert_close(result, expected)
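+
+        # The L2 grouping remap is expected to follow the classic grouped-ordering
+        # pattern from the Triton matmul tutorial (an illustrative sketch; only
+        # the num_pid_in_group and group_id names are asserted below):
+        #
+        #     num_pid_in_group = GROUP_SIZE * num_pid_1
+        #     group_id = virtual_pid // num_pid_in_group
+        #     ...pid_0 and pid_1 derived within the group...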
+
+        # Verify code contains persistent_interleaved features
+        self.assertIn("for virtual_pid in tl.range", code)
+        self.assertIn("_NUM_SM", code)
+
+        # Verify L2 grouping features are present
+        self.assertIn("num_pid_in_group", code)
+        self.assertIn("group_id", code)
+
+        # Verify 2D iteration space variables
+        self.assertIn("pid_0 = ", code)
+        self.assertIn("pid_1 = ", code)
+
+        # Test against flat for correctness comparison
+        code_flat, result_flat = code_and_output(
+            single_loop_l2_kernel, args, pid_type="flat", l2_grouping=4
+        )
+        torch.testing.assert_close(result, result_flat, atol=0, rtol=0)
+
+    def test_persistent_interleaved_multiple_loops_without_l2_grouping(self):
+        """Test persistent_interleaved with multiple top-level hl.tile loops (without l2_grouping)."""
+
+        @helion.kernel(use_default_config=True)
+        def multi_loop_kernel(
+            x: torch.Tensor, y: torch.Tensor
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            result1 = x.new_empty(x.size())
+            result2 = y.new_empty(y.size())
+
+            # First top-level hl.tile loop
+            for tile1 in hl.tile(x.size(), block_size=[16, 16]):
+                result1[tile1] = x[tile1] * 2.0
+
+            # Second top-level hl.tile loop - triggers ForEachProgramID
+            for tile2 in hl.tile(y.size(), block_size=[16, 16]):
+                result2[tile2] = y[tile2] + 1.0
+
+            return result1, result2
+
+        args = (
+            torch.randn([32, 64], device=DEVICE),
+            torch.randn([32, 64], device=DEVICE),
+        )
+
+        # Test with persistent_interleaved (no l2_grouping to avoid current limitations)
+        code, result = code_and_output(
+            multi_loop_kernel,
+            args,
+            pid_type="persistent_interleaved",
+        )
+
+        # Check correctness
+        expected1 = args[0] * 2.0
+        expected2 = args[1] + 1.0
+        torch.testing.assert_close(result[0], expected1)
+        torch.testing.assert_close(result[1], expected2)
+
+        # Verify code contains persistent_interleaved features combined with ForEachProgramID
+        self.assertIn("for virtual_pid in tl.range", code)
+
+        # Verify ForEachProgramID features (multiple loops)
+        self.assertIn("pid_shared", code)
+        self.assertIn("if pid_shared <", code)
+
+        # Test against flat for correctness comparison
+        code_flat, result_flat = code_and_output(
+            multi_loop_kernel, args, pid_type="flat"
+        )
+        torch.testing.assert_close(result[0], result_flat[0], atol=0, rtol=0)
+        torch.testing.assert_close(result[1], result_flat[1], atol=0, rtol=0)
+
+    def test_persistent_interleaved_multiple_loops_with_l2_grouping(self):
+        """Test persistent_interleaved with multiple top-level hl.tile loops AND l2_grouping (all 3 features combined)."""
+
+        @helion.kernel(use_default_config=True)
+        def multi_loop_l2_kernel(
+            x: torch.Tensor, y: torch.Tensor
+        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+            result1 = x.new_empty(x.size())
+            result2 = y.new_empty(y.size())
+            result3 = y.new_empty(y.size())
+            for tile1 in hl.tile(x.size(), block_size=[16, 16]):
+                result1[tile1] = x[tile1] * 2.0
+            for tile2 in hl.tile(y.size(), block_size=[16, 16]):
+                result2[tile2] = y[tile2] + 1.0
+            for tile3 in hl.tile(y.size(), block_size=[16, 16]):
+                result3[tile3] = y[tile3] + 2.0
+            return result1, result2, result3
+
+        args = (
+            torch.randn([32, 64], device=DEVICE),
+            torch.randn([32, 64], device=DEVICE),
+        )
+
+        # Test with persistent_interleaved + multiple loops + per-loop
+        # l2_grouping=[2, 4, 2] (all 3 features)
+        code, result = code_and_output(
+            multi_loop_l2_kernel,
+            args,
+            pid_type="persistent_interleaved",
+            l2_grouping=[2, 4, 2],
+        )
+
+        # Check correctness
+        expected1 = args[0] * 2.0
+        expected2 = args[1] + 1.0
+        expected3 = args[1] + 2.0
+        torch.testing.assert_close(result[0], expected1)
+        torch.testing.assert_close(result[1], expected2)
+        torch.testing.assert_close(result[2], expected3)
+
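+        # Note: the list form of l2_grouping above appears to supply one grouping
+        # factor per top-level loop ([2, 4, 2] for the three loops here), whereas
+        # the flat comparison below reuses a single scalar value.
+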
+        # Verify code contains persistent_interleaved features
+        self.assertIn("for virtual_pid in tl.range", code)
+        self.assertIn("_NUM_SM", code)
+
+        # Verify L2 grouping features are present
+        self.assertIn("num_pid_in_group", code)
+        self.assertIn("group_id", code)
+
+        # Verify ForEachProgramID features (multiple loops)
+        self.assertIn("pid_shared", code)
+        self.assertIn("if pid_shared <", code)
+
+        # Verify 2D iteration space variables
+        self.assertIn("pid_0 = ", code)
+        self.assertIn("pid_1 = ", code)
+
+        # Test against flat for correctness comparison
+        code_flat, result_flat = code_and_output(
+            multi_loop_l2_kernel, args, pid_type="flat", l2_grouping=4
+        )
+        torch.testing.assert_close(result[0], result_flat[0])
+        torch.testing.assert_close(result[1], result_flat[1])
+
+    @unittest.skipUnless(
+        supports_tensor_descriptor(), "Tensor descriptors not supported on this device"
+    )
+    def test_persistent_kernels_with_tensor_descriptor_indexing(self):
+        """Test persistent kernels with indexing='tensor_descriptor'."""
+
+        @helion.kernel(use_default_config=True)
+        def tensor_descriptor_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            result = x.new_empty(x.size())
+            for tile in hl.tile(x.size(), block_size=[32, 32]):
+                result[tile] = x[tile] + y[tile]
+            return result
+
+        args = (
+            torch.randn([64, 128], device=DEVICE),
+            torch.randn([64, 128], device=DEVICE),
+        )
+
+        # Test with tensor_descriptor indexing + persistent_blocked
+        code_blocked, result_blocked = code_and_output(
+            tensor_descriptor_kernel,
+            args,
+            pid_type="persistent_blocked",
+            indexing="tensor_descriptor",
+        )
+
+        # Test with tensor_descriptor indexing + persistent_interleaved
+        code_interleaved, result_interleaved = code_and_output(
+            tensor_descriptor_kernel,
+            args,
+            pid_type="persistent_interleaved",
+            indexing="tensor_descriptor",
+        )
+
+        # Check correctness
+        expected = args[0] + args[1]
+        torch.testing.assert_close(result_blocked, expected)
+        torch.testing.assert_close(result_interleaved, expected)
+
+        # Verify tensor descriptor features in code
+        self.assertIn("tl.make_tensor_descriptor", code_blocked)
+        self.assertIn("tl.make_tensor_descriptor", code_interleaved)
+
+        # Verify persistent kernel features
+        self.assertIn("for virtual_pid in tl.range", code_blocked)
+        self.assertIn("for virtual_pid in tl.range", code_interleaved)
+
+        # Verify both produce identical results
+        torch.testing.assert_close(result_blocked, result_interleaved, atol=0, rtol=0)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/test_register_tunable.py b/test/test_register_tunable.py
index 3941b48d..8e80edd2 100644
--- a/test/test_register_tunable.py
+++ b/test/test_register_tunable.py
@@ -97,7 +97,7 @@ def kernel_with_int_param(x: torch.Tensor) -> torch.Tensor:
         torch.testing.assert_close(result, expected)
         self.assertExpectedInline(
             repr(kernel_with_int_param.bind((x,)).config_spec.default_config()),
-            """helion.Config(block_sizes=[128], num_warps=4, num_stages=3, indexing='pointer', multiplier=3)""",
+            """helion.Config(block_sizes=[128], num_warps=4, num_stages=3, indexing='pointer', pid_type='flat', multiplier=3)""",
         )
         self.assertExpectedInline(
             code,