|
6 | 6 | from torch.fx import has_side_effect
|
7 | 7 |
|
8 | 8 | from .. import exc
|
| 9 | +from .._compiler.indexing_strategy import SubscriptIndexing |
9 | 10 | from . import _decorators
|
10 | 11 |
|
11 | 12 | if TYPE_CHECKING:
|
12 | 13 | import ast
|
13 | 14 |
|
14 | 15 | from .._compiler.inductor_lowering import CodegenState
|
15 | 16 |
|
| 17 | +__all__ = ["signal", "wait"] |
| 18 | + |
16 | 19 |
|
17 | 20 | @has_side_effect
|
18 | 21 | @_decorators.api(tiles_as_sizes=True)
|
@@ -146,3 +149,143 @@ def _(state: CodegenState) -> ast.AST:
|
146 | 149 | signal=signal_expr,
|
147 | 150 | update=update_expr,
|
148 | 151 | )
|
| 152 | + |
| 153 | + |
@has_side_effect
@_decorators.api(tiles_as_sizes=True)
def signal(
    signal_pad: torch.Tensor,
    index: list[object],
    signal: int = 1,
    wait_for: int | None = None,
    op: str = "atomic_xchg",
    sem: str = "release",
    scope: str = "gpu",
    skip_sync: bool = False,
) -> torch.Tensor:
    """Set the signal_pad slice to the signal value.

    Args:
        signal_pad: The signal pad to signal.
        index: Indices to index into the signal_pad tensor.
        signal: The value to send (default: 1).
        wait_for: The value to wait for before sending the signal.
            Only valid for ``op='atomic_cas'``.
        op: The memory op for acquiring the lock (default: 'atomic_xchg').
        sem: The memory semantic for acquiring the lock (default: 'release').
        scope: The scope of the lock (default: 'gpu').
        skip_sync: Skip the syncthreads before sending signal (default: False).

    Returns:
        A tensor with the shape of the indexed slice of ``signal_pad``
        (see the fake implementation registered for this op).

    Raises:
        NotInsideKernel: Always, when called eagerly; inside a Helion
            kernel the call is intercepted and compiled instead.
    """
    raise exc.NotInsideKernel
| 178 | + |
| 179 | + |
@_decorators.prepare_args(signal)
def _(
    signal_pad: torch.Tensor,
    index: list[object],
    signal: int = 1,
    wait_for: int | None = None,
    op: str = "atomic_xchg",
    sem: str = "release",
    scope: str = "gpu",
    skip_sync: bool = False,
) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool]:
    """Validate the signal arguments and canonicalize the index for tracing."""
    from helion.language.tile_proxy import Tile

    valid_ops = {"atomic_add", "atomic_xchg", "atomic_cas"}
    valid_sems = {"relaxed", "release", "acq_rel"}
    valid_scopes = {"sys", "gpu"}

    # Reject unknown ops up front; the checks below may assume op is valid.
    if op not in valid_ops:
        raise ValueError(f"Invalid signal op '{op}'. Must be one of {valid_ops}. ")

    # atomic_cas requires an expected value; the other ops must not have one.
    if op == "atomic_cas":
        if wait_for is None:
            raise ValueError(
                f"{op} without a wait_for value. Do you want to use 'atomic_add' or 'atomic_xchg' instead? "
            )
    elif wait_for is not None:
        raise ValueError(
            f"{op} with a wait_for value. Do you want to use 'atomic_cas' instead? "
        )

    if sem not in valid_sems:
        raise ValueError(
            f"Invalid memory semantic '{sem}'. Must be one of {valid_sems}."
        )

    if scope not in valid_scopes:
        raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")

    # Normalize tile-style indices into plain sizes before handing off.
    prepared = Tile._tiles_to_sizes(Tile._prepare_index(index))

    return (signal_pad, prepared, signal, wait_for, op, sem, scope, skip_sync)
| 221 | + |
| 222 | + |
@_decorators.register_fake(signal)
def _(
    signal_pad: torch.Tensor,
    index: list[object],
    signal: int = 1,
    wait_for: int | None = None,
    op: str = "atomic_xchg",
    sem: str = "release",
    scope: str = "gpu",
    skip_sync: bool = False,
) -> torch.Tensor:
    """Fake-tensor implementation: an empty tensor shaped like the indexed slice."""
    out_shape = SubscriptIndexing.compute_shape(signal_pad, index)
    return signal_pad.new_empty(out_shape)
| 235 | + |
| 236 | + |
@_decorators.codegen(signal)
def _(state: CodegenState) -> ast.AST:
    """Emit the device-side call for ``signal``.

    For ``op='atomic_cas'`` this emits a wait-then-update call
    (``triton_wait_signal`` / ``triton_wait_multiple_signal``); for the
    other ops it emits a plain ``triton_send_signal`` call.
    """
    import ast

    from .._compiler.ast_extension import expr_from_string
    from .._compiler.indexing_strategy import SubscriptIndexing

    # Positional proxy args mirror the signal() signature exactly.
    signal_pad = state.proxy_arg(0)
    index = state.proxy_arg(1)
    signal = state.proxy_arg(2)
    wait_for = state.proxy_arg(3)
    op = state.proxy_arg(4)
    sem = state.proxy_arg(5)
    scope = state.proxy_arg(6)
    skip_sync = state.proxy_arg(7)

    assert isinstance(signal_pad, torch.Tensor)
    assert isinstance(index, list)

    indices = SubscriptIndexing.create(state, signal_pad, index)
    signal_pad_name = state.device_function.tensor_arg(signal_pad).name

    signal_expr = ast.Constant(value=signal)
    # wait_for is only meaningful for atomic_cas (enforced by prepare_args);
    # the 0 fallback keeps the placeholder well-formed but is unused otherwise.
    if wait_for is not None:
        wait_for_expr = ast.Constant(value=wait_for)
    else:
        wait_for_expr = ast.Constant(value=0)
    skip_sync_expr = ast.Constant(value=skip_sync)
    assert type(op) is str
    assert type(sem) is str
    assert type(scope) is str
    # NOTE(review): skip_sync is not type-asserted like op/sem/scope — confirm
    # whether a bool assert was intended here.

    if op == "atomic_cas":
        # Scalar vs. tensor slice decides which runtime helper to call.
        bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
        is_scalar = len(bar_tensor_shape) == 0
        # The helper itself re-syncs, so skip_sync=True here and the user's
        # skip_sync flag is applied via sync_before instead.
        if is_scalar:
            call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))"
        else:
            call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync), sync_after=True)"

        # offset/wait_for/signal/skip_sync are placeholders substituted by
        # expr_from_string into the template above.
        return expr_from_string(
            call_triton_wait_signal,
            offset=indices.index_expr,
            wait_for=wait_for_expr,
            signal=signal_expr,
            skip_sync=skip_sync_expr,
        )
    call_triton_send_signal = f"helion.runtime.triton_send_signal(addr={signal_pad_name} + offset, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=skip_sync)"

    return expr_from_string(
        call_triton_send_signal,
        offset=indices.index_expr,
        signal=signal_expr,
        skip_sync=skip_sync_expr,
    )
0 commit comments