Skip to content

Parallel segment reduction using tl.associative_scan #237

Closed
@cw-tan

Description

@cw-tan

I'm trying to write Helion to autotune something based on the segment reduction Triton kernel from https://github.com/fishmingyu/GeoT/blob/main/geot/triton/seg_reduction.py, which makes use of tl.associative_scan.

Here's an attempt that failed.

# Triton `parallel_segment_reduction_kernel` adapted from
# https://github.com/fishmingyu/GeoT/blob/main/geot/triton/seg_reduction.py

import torch
import triton
import triton.language as tl

import helion
import helion.language as hl


@triton.jit
def combine_fn(left_values, left_indices, right_values, right_indices):
    same_segment = left_indices == right_indices
    combined_values = tl.where(same_segment, left_values + right_values, right_values)
    combined_indices = right_indices
    return combined_values, combined_indices


@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE": 16},
        ),
    ],
    key=["C"],
    restore_value=["out_ptr"],
)
@triton.jit
def parallel_segment_reduction_kernel(
    index,  # the input index tensor
    in_ptr,  # the input tensor
    out_ptr,  # the output value tensor
    E: tl.constexpr,  # Number of elements in the input tensor (1d)
    C: tl.constexpr,  # Number of features in the input tensor (2d)
    BLOCK_SIZE: tl.constexpr,  # Block size for the scan
):
    pid = tl.program_id(axis=0)
    offset_pid = pid // C
    feature_id = pid % C
    offsets = offset_pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < E

    # Load input data
    vals = tl.load(in_ptr + offsets * C + feature_id, mask=mask)
    idxs = tl.load(index + offsets, mask=mask)
    idxs_next = tl.load(index + offsets + 1, offsets < E - 1)

    # Perform an inclusive scan using tl.associative_scan
    result_values, _ = tl.associative_scan(
        (
            vals,
            idxs,
        ),
        axis=0,
        combine_fn=combine_fn,
    )
    # if offset % BLOCK_SIZE == -1, it means the last element of the segment
    segment_start = (idxs != idxs_next) | (offsets % BLOCK_SIZE == BLOCK_SIZE - 1)
    tl.atomic_add(out_ptr + idxs * C + feature_id, result_values, mask & segment_start)


def launch_parallel_reduction(indices, input, num_nodes):
    E, C = input.shape
    output = torch.zeros((num_nodes, C), dtype=input.dtype, device=input.device)

    def grid(META):
        return (triton.cdiv(E, META["BLOCK_SIZE"]) * C,)

    parallel_segment_reduction_kernel[grid](indices, input, output, E, C)
    return output


def helion_combine_fn(left_values, left_indices, right_values, right_indices):
    same_segment = left_indices == right_indices
    combined_values = torch.where(
        same_segment, left_values + right_values, right_values
    )
    combined_indices = right_indices
    return combined_values, combined_indices


@helion.kernel(
    use_default_config=True,
)
def segmented_reduction_helion(
    indices: torch.Tensor, input_data: torch.Tensor, num_nodes: int
) -> torch.Tensor:
    E, C = input_data.shape
    output = torch.zeros(
        (num_nodes, C), dtype=input_data.dtype, device=input_data.device
    )

    for tile_e, tile_f in hl.tile([E, C]):
        vals = input_data[tile_e, tile_f]
        idxs = indices[tile_e]
        idxs_next = indices[tile_e.index + 1]
        # ^ https://github.com/pytorch-labs/helion/blob/0996865cb3bdc1d0bffa6b8a86cf5fce1a980fb3/test/test_indexing.py#L68

        # not sure if this is expected to work
        out_vals, _ = torch._higher_order_ops.associative_scan(
            helion_combine_fn, (vals, idxs), 0
        )

        # what's the analogous form for:
        # segment_start = (idxs != idxs_next) | (offsets % BLOCK_SIZE == BLOCK_SIZE - 1)
        # i.e. unclear how to get `(offsets % BLOCK_SIZE == BLOCK_SIZE - 1)`
        mask = idxs != idxs_next
        segment_vals = torch.where(mask.unsqueeze(1), out_vals, 0.0)
        hl.atomic_add(output, [idxs, tile_f], segment_vals)

    return output


def test_segmented_reduction():
    # Create test data
    num_nodes = 100
    num_edges = 1000
    C = 32

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float32

    # Create sorted indices for segmented reduction
    indices = torch.randint(0, num_nodes, (num_edges,), device=device).sort()[0]
    input_data = torch.randn(num_edges, C, device=device, dtype=dtype)

    # Run PyTorch reference (scatter_add equivalent)
    pytorch_output = torch.zeros(num_nodes, C, device=device, dtype=dtype)
    pytorch_output.scatter_add_(0, indices.unsqueeze(1).expand(-1, C), input_data)

    # Test Triton implementation
    triton_output = launch_parallel_reduction(indices, input_data, num_nodes)
    torch.testing.assert_close(triton_output, pytorch_output, rtol=1e-4, atol=1e-4)
    print("Test passed! Triton and PyTorch outputs match.")

    # Test Helion implementation
    helion_output = segmented_reduction_helion(indices, input_data, num_nodes)
    torch.testing.assert_close(helion_output, pytorch_output, rtol=1e-4, atol=1e-4)
    print("Test passed! Helion and PyTorch outputs match.")

    return True


def main():
    test_segmented_reduction()


if __name__ == "__main__":
    main()

I wasn't really sure if using torch._higher_order_ops.associative_scan would get the intended outcome, but just to demonstrate a one-to-one mapping between Triton and a plausible Helion implementation structure. The other part that I wasn't sure of was the mask

segment_start = (idxs != idxs_next) | (offsets % BLOCK_SIZE == BLOCK_SIZE - 1)

specifically, the (offsets % BLOCK_SIZE == BLOCK_SIZE - 1) part.

Would we expect something like this to be supported by Helion, be it with custom workarounds or something out-of-the-box? Thanks!

Stacktrace looks like

Test passed! Triton and PyTorch outputs match.
E0703 11:21:28.382000 19234 site-packages/torch/fx/experimental/recording.py:299] [0/0] failed while running _create_symbolic_sizes_strides_storage_offset(*((64, 64), (64, 1), 0, [False, False], GetItemSource(base=LocalSource(local_name='xs', is_input=True, dynamism=None, is_derefed_cell_contents=False), index=0, index_is_slice=False)), **{'symbolic_context': StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.DYNAMIC: 0>, <DimDynamic.DYNAMIC: 0>], dynamic_strides=[<DimDynamic.INFER_STRIDE: 4>, <DimDynamic.INFER_STRIDE: 4>], constraint_sizes=[None, None], constraint_strides=[None, None], view_base_context=None, tensor_source=GetItemSource(base=LocalSource(local_name='xs', is_input=True, dynamism=None, is_derefed_cell_contents=False), index=0, index_is_slice=False), shape_env_to_source_to_symbol_cache={})})
Traceback (most recent call last):
  File "/home/cwtan/nequip-allegro/helion/helion/_compiler/type_propagation.py", line 745, in propagate_call
    _CheckForIndexCalls.retry_call(
  File "/home/cwtan/nequip-allegro/helion/helion/language/tile_proxy.py", line 166, in retry_call
    return fn(*proxy_args, **proxy_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_higher_order_ops/associative_scan.py", line 146, in associative_scan
    return torch.compile(associative_scan, fullgraph=True, backend="eager")(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 655, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1432, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 598, in __call__
    return _compile(
           ^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1110, in _compile
    raise InternalTorchDynamoError(
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1059, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_utils_internal.py", line 97, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 761, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 797, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1422, in transform_code_object
    transformations(instructions, code_options)
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 257, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 715, in transform
    tracer.run()
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3498, in run
    super().run()
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1337, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1246, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 819, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2931, in CALL
    self._call(inst)
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2925, in _call
    self.call_function(fn, args, kwargs)
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1170, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 414, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 184, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1187, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3724, in inline_call
    return tracer.inline_call_()
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3903, in inline_call_
    self.run()
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1337, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1246, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 819, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2931, in CALL
    self._call(inst)
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2925, in _call
    self.call_function(fn, args, kwargs)
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1170, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 184, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1187, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3724, in inline_call
    return tracer.inline_call_()
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3903, in inline_call_
    self.run()
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1337, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1246, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 819, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2931, in CALL
    self._call(inst)
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2925, in _call
    self.call_function(fn, args, kwargs)
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1170, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 184, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1187, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3724, in inline_call
    return tracer.inline_call_()
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3903, in inline_call_
    self.run()
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1337, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1246, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1432, in STORE_FAST
    loaded_vt.set_name_hint(name)
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/variables/lazy.py", line 201, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
                   ^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/variables/lazy.py", line 67, in realize
    self._cache.realize()
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/variables/lazy.py", line 33, in realize
    self.vt = VariableTracker.build(tx, self.value, source)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/variables/base.py", line 540, in build
    return builder.VariableBuilder(tx, source)(value)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 417, in __call__
    vt = self._wrap(value)
         ^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 600, in _wrap
    return type_dispatch(self, value)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 1795, in wrap_tensor
    example_value = wrap_to_fake_tensor_and_record(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 3032, in wrap_to_fake_tensor_and_record
    fake_e = wrap_fake_exception(
             ^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2641, in wrap_fake_exception
    return fn()
           ^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 3033, in <lambda>
    lambda: tx.fake_mode.from_tensor(
            ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2658, in from_tensor
    return self.fake_tensor_converter.from_real_tensor(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 392, in from_real_tensor
    out = self.meta_converter(
          ^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_subclasses/meta_utils.py", line 1889, in __call__
    r = self.meta_tensor(
        ^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_subclasses/meta_utils.py", line 1656, in meta_tensor
    ) = sym_sizes_strides_storage_offset(t, source, symbolic_context)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_subclasses/meta_utils.py", line 930, in sym_sizes_strides_storage_offset
    return shape_env._create_symbolic_sizes_strides_storage_offset(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/fx/experimental/recording.py", line 263, in wrapper
    return retlog(fn(*args, **kwargs))
                  ^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 4053, in _create_symbolic_sizes_strides_storage_offset
    size: list[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple(
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 3873, in _produce_dyn_sizes_from_int_tuple
    assert all(
           ^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 3874, in <genexpr>
    not is_symbolic(val) for val in tensor_size
        ^^^^^^^^^^^^^^^^
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 1902, in is_symbolic
    return val.node.is_symbolic()
           ^^^^^^^^
torch._dynamo.exc.InternalTorchDynamoError: AttributeError: 'Integer' object has no attribute 'node'

from user code:
   File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_higher_order_ops/associative_scan.py", line 150, in associative_scan
    leaves, spec = pytree.tree_flatten(xs)
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/utils/_pytree.py", line 1055, in tree_flatten
    treespec = helper(tree, leaves)
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/utils/_pytree.py", line 1051, in helper
    subspecs = [helper(child, leaves) for child in children]
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/utils/_pytree.py", line 1051, in <listcomp>
    subspecs = [helper(child, leaves) for child in children]

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/cwtan/nequip-allegro/test/segred_triton.py", line 150, in <module>
    main()
  File "/home/cwtan/nequip-allegro/test/segred_triton.py", line 146, in main
    test_segmented_reduction()
  File "/home/cwtan/nequip-allegro/test/segred_triton.py", line 138, in test_segmented_reduction
    helion_output = segmented_reduction_helion(indices, input_data, num_nodes)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/nequip-allegro/helion/helion/runtime/kernel.py", line 232, in __call__
    return self.bind(args)(*args)
           ^^^^^^^^^^^^^^^
  File "/home/cwtan/nequip-allegro/helion/helion/runtime/kernel.py", line 129, in bind
    bound_kernel = BoundKernel(self, args)
                   ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/nequip-allegro/helion/helion/runtime/kernel.py", line 289, in __init__
    self.host_function: HostFunction = HostFunction(
                                       ^^^^^^^^^^^^^
  File "/home/cwtan/nequip-allegro/helion/helion/_compiler/host_function.py", line 107, in __init__
    propagate_types(self, fake_args)
  File "/home/cwtan/nequip-allegro/helion/helion/_compiler/type_propagation.py", line 2079, in propagate_types
    prop.visit(stmt)
  File "/home/cwtan/nequip-allegro/helion/helion/_compiler/type_propagation.py", line 1475, in visit
    type_info = visitor(node)
                ^^^^^^^^^^^^^
  File "/home/cwtan/nequip-allegro/helion/helion/_compiler/type_propagation.py", line 1971, in visit_For
    body = self._loop_body(node.body)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/nequip-allegro/helion/helion/_compiler/type_propagation.py", line 1935, in _loop_body
    self.visit(stmt)
  File "/home/cwtan/nequip-allegro/helion/helion/_compiler/type_propagation.py", line 1475, in visit
    type_info = visitor(node)
                ^^^^^^^^^^^^^
  File "/home/cwtan/nequip-allegro/helion/helion/_compiler/type_propagation.py", line 1859, in visit_Assign
    type_info = self.visit(node.value)
                ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/nequip-allegro/helion/helion/_compiler/type_propagation.py", line 1475, in visit
    type_info = visitor(node)
                ^^^^^^^^^^^^^
  File "/home/cwtan/nequip-allegro/helion/helion/_compiler/type_propagation.py", line 1801, in visit_Call
    return func.propagate_call(tuple(args), kwargs, self.origin())
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cwtan/nequip-allegro/helion/helion/_compiler/type_propagation.py", line 779, in propagate_call
    raise exc.TorchOpTracingError(e) from e
helion.exc.TorchOpTracingError: InternalTorchDynamoError: AttributeError: 'Integer' object has no attribute 'node'

from user code:
   File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_higher_order_ops/associative_scan.py", line 150, in associative_scan
    leaves, spec = pytree.tree_flatten(xs)
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/utils/_pytree.py", line 1055, in tree_flatten
    treespec = helper(tree, leaves)
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/utils/_pytree.py", line 1051, in helper
    subspecs = [helper(child, leaves) for child in children]
  File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/utils/_pytree.py", line 1051, in <listcomp>
    subspecs = [helper(child, leaves) for child in children]

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

While processing:
  File "/home/cwtan/nequip-allegro/test/segred_triton.py", line 101, in segmented_reduction_helion
    out_vals, _ = torch._higher_order_ops.associative_scan(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Metadata

Metadata

Assignees

Labels

No labels
No labels

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions