```diff
@@ -8,15 +8,12 @@
 from typing import overload
 
 import torch
-from torch._inductor.codegen.simd import constant_repr
-from torch._inductor.runtime.runtime_utils import next_power_of_2
 from torch._inductor.runtime.triton_heuristics import get_max_y_grid
 
 from .. import exc
 from .._compiler.ast_extension import ExtendedAST
 from .._compiler.ast_extension import LoopType
 from .._compiler.ast_extension import expr_from_string
-from .._compiler.compile_environment import AutoSize
 from .._compiler.compile_environment import CompileEnvironment
 from .._compiler.tile_index_proxy import TileIndexProxy
 from .._compiler.type_propagation import GridIndexType
@@ -26,7 +23,6 @@
 from .._compiler.type_propagation import TileIndexType
 from .._compiler.type_propagation import TypeInfo
 from .._compiler.type_propagation import UnknownType
-from ..autotuner.config_fragment import assert_integer_power_of_two
 from ..autotuner.config_spec import ConfigSpec
 from ..autotuner.config_spec import FlattenLoopSpec
 from ..autotuner.config_spec import L2GroupingSpec
@@ -39,7 +35,7 @@
 from .._compiler.inductor_lowering import CodegenState
 
 
-__all__ = ["Tile", "grid", "register_block_size", "register_reduction_dim", "tile"]
+__all__ = ["Tile", "grid", "tile"]
 Tile = TileIndexProxy
 
 
```
```diff
@@ -372,24 +368,3 @@ def _(state: CodegenState) -> ast.AST:
         state.tile_strategy.codegen_grid(state, block_ids)
         return expr_from_string("None")
     raise AssertionError(f"Expected loop type: {loop_type}")
-
-
-@_decorators.api(is_device_only=False, cache_type=True, tiles_as_sizes=True)
-def register_block_size(min_or_max: int, max_or_none: int | None = None, /) -> int:
-    """
-    Explicitly register a block size that should be autotuned and can be used for
-    allocations and inside hl.tile(..., block_size=...).
-
-    This is useful if you have two loops where you want them to share a block size,
-    or if you need to allocate a kernel tensor before the hl.tile() loop.
-
-    The signature can be one of:
-        hl.register_block_size(max)
-        hl.register_block_size(min, max)
-
-    Where min and max are integers that control the range of block_sizes searched by
-    the autotuner. Max may be a symbolic shape, but min must be a constant integer.
-    """
-    raise exc.NotInsideKernel
-
-
```
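For reference, the removed docstring above describes the two call signatures. A minimal sketch of the shared-block-size pattern this API supported, reconstructed from that docstring alone (the `@helion.kernel` decorator, the `double_both` kernel, and the loop bodies are illustrative assumptions, not taken from this commit):

```python
import torch

import helion
import helion.language as hl


# Hypothetical kernel illustrating the removed API; only the
# hl.register_block_size(...) and hl.tile(..., block_size=...) calls
# come from the docstring above.
@helion.kernel()
def double_both(x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    out_x = torch.empty_like(x)
    out_y = torch.empty_like(y)
    # One autotuned block size shared by the two loops below,
    # via the hl.register_block_size(max) form:
    block = hl.register_block_size(x.size(0))
    # The hl.register_block_size(min, max) form would instead bound
    # the autotuner's search range from both sides.
    for tile in hl.tile(x.size(0), block_size=block):
        out_x[tile] = x[tile] * 2
    for tile in hl.tile(y.size(0), block_size=block):
        out_y[tile] = y[tile] * 2
    return out_x, out_y
```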
```diff
@@ -396,26 +370,0 @@
-@_decorators.type_propagation(register_block_size)
-def _(
-    min_or_max: TypeInfo, max_or_none: TypeInfo | None = None, /, *, origin: Origin
-) -> TypeInfo:
-    from .._compiler.type_propagation import SymIntType
-
-    min_type, max_type = _normalize_begin_end(min_or_max, max_or_none, origin=origin)
-    min_proxy = _to_proxy(min_type)
-    max_proxy = _to_proxy(max_type)
-    if not isinstance(max_proxy, (int, torch.SymInt)):
-        raise exc.IncorrectTileUsage(
-            f"expected max to be an integer or size, got {max_proxy!s}"
-        )
-    if not isinstance(min_proxy, int):
-        raise exc.IncorrectTileUsage(
-            f"expected min to be an integer constant, got {min_proxy!s}"
-        )
-    env = CompileEnvironment.current()
-    result = TileIndexType.allocate(AutoSize(), origin)
-    loop_spec = env.config_spec.block_sizes.block_id_lookup(result.block_id)
-    loop_spec.min_size = assert_integer_power_of_two(max(1, min_proxy))
-    loop_spec.max_size = next_power_of_2(env.size_hint(max_proxy))
-    block_id = result.block_id
-    return SymIntType(origin, env.block_sizes[block_id].var)
-
-
```
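The two helpers on the `min_size`/`max_size` lines clamp the autotuner's search range to powers of two. Pure-Python stand-ins showing the intended arithmetic (the real implementations are the `next_power_of_2` and `assert_integer_power_of_two` imports deleted above; these reimplementations are for illustration only):

```python
def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n, e.g. 1000 -> 1024 (illustrative stand-in).
    return 1 << (n - 1).bit_length() if n > 1 else 1


def assert_integer_power_of_two(n: int) -> int:
    # min_size must already be a power of two; reject anything else.
    assert n >= 1 and (n & (n - 1)) == 0, f"{n} is not a power of two"
    return n


# How the removed type propagation derived the search bounds:
min_size = assert_integer_power_of_two(max(1, 16))  # user min=16 -> 16
max_size = next_power_of_2(1000)                    # size hint 1000 -> 1024
```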
```diff
@@ -422,29 +370,0 @@
-def _block_id_from_state(state: CodegenState) -> int:
-    """Extract the block_id from the current state for nodes created by hl.register_block_size."""
-    from .._compiler.type_propagation import SymIntType
-
-    env = CompileEnvironment.current()
-    if state.fx_node is not None:
-        val = state.fx_node.meta["val"]
-        assert isinstance(val, SymIntType)
-        block_id = env.get_block_id(val.value)
-        assert block_id is not None
-        return block_id
-    current_node = ExtendedAST.current()[-1]
-    type_info = current_node._type_info
-    assert isinstance(type_info, SymIntType)
-    block_id = env.get_block_id(type_info.value)
-    assert block_id is not None
-    return block_id
-
-
-@_decorators.codegen(register_block_size)
-def _(state: CodegenState) -> ast.AST:
-    env = CompileEnvironment.current()
-    block_size = env.config_spec.block_sizes.config_get(
-        state.config.block_sizes, _block_id_from_state(state)
-    )
-    assert block_size is not None
-    return expr_from_string(constant_repr(block_size))
-
-
```
```diff
@@ -451,17 +370,0 @@
-@_decorators.api(is_device_only=False, cache_type=True, tiles_as_sizes=True)
-def register_reduction_dim(
-    size: int,
-) -> int:
-    """
-    Explicitly register a reduction dimension that should be used for reduction operations.
-
-    This is useful when you need to allocate a dimension for reduction that isn't
-    automatically inferred from a slice operation. The registered dimension can be
-    used for allocations and operations that require knowing the reduction size upfront.
-
-    :param size: An integer representing the reduction dimension size.
-    :return: A SymInt object representing the reduction dimension size.
-    """
-    raise exc.NotInsideKernel
-
-
```
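A sketch of the allocate-before-the-loop pattern this docstring describes, under stated assumptions: the kernel structure, the `row_sums` name, and the `hl.zeros` allocation helper are illustrative guesses, not confirmed by this diff; only `hl.register_reduction_dim(size)` itself comes from the removed code.

```python
import torch

import helion
import helion.language as hl


@helion.kernel()
def row_sums(x: torch.Tensor) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty([m], dtype=x.dtype, device=x.device)
    # Register the reduction dimension up front so an accumulator can be
    # sized by it before any slice implies it (per the docstring above).
    rdim = hl.register_reduction_dim(n)
    for tile_m in hl.tile(m):
        acc = hl.zeros([tile_m, rdim], dtype=x.dtype)  # hl.zeros is assumed
        acc += x[tile_m, :]        # the slice's reduction dim matches rdim
        out[tile_m] = acc.sum(-1)
    return out
```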
```diff
@@ -468,34 +370,0 @@
-@_decorators.type_propagation(register_reduction_dim)
-def _(sizes: TypeInfo, *, origin: Origin) -> TypeInfo:
-    from .._compiler.compile_environment import CompileEnvironment
-    from .._compiler.type_propagation import SymIntType
-
-    try:
-        proxy_sizes = sizes.proxy()
-        if not isinstance(proxy_sizes, int | torch.SymInt):
-            raise NotImplementedError
-    except NotImplementedError:
-        raise exc.TypePropagationError(
-            UnknownType(
-                origin,
-                f"register_reduction_dim() expected int or list[int], got {sizes!s}",
-                chained_from=sizes,
-            )
-        ) from None
-
-    env = CompileEnvironment.current()
-
-    rdim = env.allocate_reduction_dimension(proxy_sizes)
-    return SymIntType(origin, rdim.var)
-
-
-@_decorators.codegen(register_reduction_dim)
-def _(state: CodegenState) -> ast.AST:
-    """Generate code for register_reduction_dim - return the size expression"""
-    from .._compiler.type_propagation import SymIntType
-
-    current_node = ExtendedAST.current()[-1]
-    type_info = current_node._type_info
-
-    assert isinstance(type_info, SymIntType)
-    return current_node.args[0]  # pyre-ignore[16]
```