Closed
Description
I'm trying to write a Helion kernel, so that it can be autotuned, based on the segment reduction Triton kernel from https://github.com/fishmingyu/GeoT/blob/main/geot/triton/seg_reduction.py, which makes use of `tl.associative_scan`.
Here's an attempt that failed.
# Triton `parallel_segment_reduction_kernel` adapted from
# https://github.com/fishmingyu/GeoT/blob/main/geot/triton/seg_reduction.py
import torch
import triton
import triton.language as tl
import helion
import helion.language as hl
@triton.jit
def combine_fn(left_values, left_indices, right_values, right_indices):
    """Combine two (value, segment index) pairs for the segmented scan.

    When both operands carry the same segment index their values are added;
    otherwise the right-hand value is kept as-is, which restarts the running
    sum. The right-hand index is always the one propagated.
    """
    summed = left_values + right_values
    values = tl.where(left_indices == right_indices, summed, right_values)
    return values, right_indices
# `restore_value` is needed because the kernel accumulates into `out_ptr` with
# atomics: without it, the autotuner's benchmark launches would leave extra
# contributions in the output buffer.
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE": 16},
        ),
    ],
    key=["C"],
    restore_value=["out_ptr"],
)
@triton.jit
def parallel_segment_reduction_kernel(
    index,  # the input index tensor
    in_ptr,  # the input tensor
    out_ptr,  # the output value tensor
    E: tl.constexpr,  # Number of elements in the input tensor (1d)
    C: tl.constexpr,  # Number of features in the input tensor (2d)
    BLOCK_SIZE: tl.constexpr,  # Block size for the scan
):
    """Segmented sum of one feature column over one block of elements.

    Each program handles ``BLOCK_SIZE`` consecutive elements of a single
    feature: ``pid // C`` selects the element block and ``pid % C`` the
    feature. Within the block, an inclusive scan accumulates values over runs
    of equal ``index`` entries, and the last element of every run atomically
    adds the run's sum to ``out_ptr[index, feature]``.
    """
    pid = tl.program_id(axis=0)
    offset_pid = pid // C
    feature_id = pid % C
    offsets = offset_pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < E

    # Load input data. Lanes past E hold unspecified data, but they sit after
    # every valid lane of the block, so the inclusive scan results of the
    # valid lanes do not depend on them.
    vals = tl.load(in_ptr + offsets * C + feature_id, mask=mask)
    idxs = tl.load(index + offsets, mask=mask)
    # The final element (offset E - 1) has no successor. A masked load without
    # `other` leaves that lane unspecified, so the last run would only be
    # flushed if the garbage happened to differ from `idxs`. Fill it with -1,
    # which can never equal a valid (non-negative) output row index.
    idxs_next = tl.load(index + offsets + 1, mask=offsets < E - 1, other=-1)

    # Perform an inclusive scan using tl.associative_scan
    result_values, _ = tl.associative_scan(
        (
            vals,
            idxs,
        ),
        axis=0,
        combine_fn=combine_fn,
    )

    # NOTE: despite its name, `segment_start` flags the LAST element of each
    # run: either the next index differs, or this is the last lane of the
    # block (offset % BLOCK_SIZE == BLOCK_SIZE - 1). The scan does not cross
    # block boundaries, so a run that continues into the next block must
    # flush its partial sum here; the next block adds the remainder.
    segment_start = (idxs != idxs_next) | (offsets % BLOCK_SIZE == BLOCK_SIZE - 1)
    tl.atomic_add(
        out_ptr + idxs * C + feature_id,
        result_values,
        mask=mask & segment_start,
    )
def launch_parallel_reduction(indices, input, num_nodes):
    """Run the Triton segmented reduction and return the per-node sums.

    Allocates a zero-filled ``(num_nodes, C)`` tensor with the dtype and
    device of ``input`` (shape ``(E, C)``), launches
    ``parallel_segment_reduction_kernel`` on it and returns it.
    """
    num_elements, num_features = input.shape
    output = input.new_zeros((num_nodes, num_features))

    def launch_grid(meta):
        # One program per (block of BLOCK_SIZE elements, feature) pair.
        blocks = triton.cdiv(num_elements, meta["BLOCK_SIZE"])
        return (blocks * num_features,)

    parallel_segment_reduction_kernel[launch_grid](
        indices, input, output, num_elements, num_features
    )
    return output
def helion_combine_fn(left_values, left_indices, right_values, right_indices):
    """Combine two (value, segment index) pairs for the segmented scan.

    Positions where the left and right indices match get the sum of both
    values; everywhere else the right-hand value is kept, restarting the
    running sum. The right-hand indices are returned unchanged.
    """
    accumulated = left_values + right_values
    values = torch.where(left_indices == right_indices, accumulated, right_values)
    return values, right_indices
@helion.kernel(
    use_default_config=True,
)
def segmented_reduction_helion(
    indices: torch.Tensor, input_data: torch.Tensor, num_nodes: int
) -> torch.Tensor:
    """Attempted Helion port of ``parallel_segment_reduction_kernel``.

    Intended to sum the rows of ``input_data`` (shape ``(E, C)``) that share
    the same entry in ``indices`` into the matching row of a zero-initialised
    ``(num_nodes, C)`` output.

    NOTE: this is the failing reproducer of this report. The
    ``associative_scan`` call below is the statement being processed when
    Helion raises ``TorchOpTracingError``.
    """
    E, C = input_data.shape
    output = torch.zeros(
        (num_nodes, C), dtype=input_data.dtype, device=input_data.device
    )
    # Tile jointly over the element and feature dimensions.
    for tile_e, tile_f in hl.tile([E, C]):
        vals = input_data[tile_e, tile_f]
        idxs = indices[tile_e]
        idxs_next = indices[tile_e.index + 1]
        # ^ https://github.com/pytorch-labs/helion/blob/0996865cb3bdc1d0bffa6b8a86cf5fce1a980fb3/test/test_indexing.py#L68
        # not sure if this is expected to work
        # NOTE(review): the Triton kernel guards the equivalent load with
        # `offsets < E - 1`; nothing here guards the final element, so this
        # looks like it can read `indices[E]` -- confirm how Helion masks it.
        # NOTE(review): `vals` is indexed by (tile_e, tile_f) while `idxs` is
        # indexed by tile_e only, so the two scan operands differ in rank --
        # confirm whether the combine function is expected to broadcast.
        out_vals, _ = torch._higher_order_ops.associative_scan(
            helion_combine_fn, (vals, idxs), 0
        )
        # what's the analogous form for:
        # segment_start = (idxs != idxs_next) | (offsets % BLOCK_SIZE == BLOCK_SIZE - 1)
        # i.e. unclear how to get `(offsets % BLOCK_SIZE == BLOCK_SIZE - 1)`
        # NOTE(review): without that "last lane of the tile" term, a run that
        # continues past the end of a tile is presumably never flushed from
        # this tile -- TODO confirm once the tile-local offset is available.
        mask = idxs != idxs_next
        # Rows that do not end a run contribute 0, so the atomic add below
        # only changes the output for the last row of each run.
        segment_vals = torch.where(mask.unsqueeze(1), out_vals, 0.0)
        hl.atomic_add(output, [idxs, tile_f], segment_vals)
    return output
def test_segmented_reduction():
    """Compare the Triton and Helion reductions with a PyTorch reference.

    Builds random inputs with sorted segment ids, computes the expected
    result with ``scatter_add_``, and asserts that each implementation
    matches it within ``rtol = atol = 1e-4``. Returns ``True`` when both
    comparisons pass.
    """
    num_nodes, num_edges, num_features = 100, 1000, 32
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float32

    # Segment ids are sorted so that equal ids form contiguous runs, which is
    # what the adjacent-index comparison in the kernels relies on.
    unsorted_ids = torch.randint(0, num_nodes, (num_edges,), device=device)
    indices = torch.sort(unsorted_ids).values
    input_data = torch.randn(num_edges, num_features, device=device, dtype=dtype)

    # PyTorch reference (scatter_add equivalent).
    pytorch_output = torch.zeros(num_nodes, num_features, device=device, dtype=dtype)
    scatter_index = indices.unsqueeze(1).expand(-1, num_features)
    pytorch_output.scatter_add_(0, scatter_index, input_data)

    tolerances = {"rtol": 1e-4, "atol": 1e-4}

    # Triton implementation.
    triton_output = launch_parallel_reduction(indices, input_data, num_nodes)
    torch.testing.assert_close(triton_output, pytorch_output, **tolerances)
    print("Test passed! Triton and PyTorch outputs match.")

    # Helion implementation.
    helion_output = segmented_reduction_helion(indices, input_data, num_nodes)
    torch.testing.assert_close(helion_output, pytorch_output, **tolerances)
    print("Test passed! Helion and PyTorch outputs match.")

    return True
def main():
    """Script entry point: run the segmented-reduction comparison once."""
    test_segmented_reduction()
# Run the comparison only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
I wasn't sure whether using `torch._higher_order_ops.associative_scan` would give the intended outcome; I used it mainly to demonstrate a one-to-one mapping between the Triton kernel and a plausible Helion implementation structure. The other part that I wasn't sure of was the mask
`segment_start = (idxs != idxs_next) | (offsets % BLOCK_SIZE == BLOCK_SIZE - 1)`,
specifically the `(offsets % BLOCK_SIZE == BLOCK_SIZE - 1)` part, i.e. how to express "last element of the current block" in Helion.
Would we expect something like this to be supported by Helion, be it with custom workarounds or something out-of-the-box? Thanks!
The output and stack trace look like this:
Test passed! Triton and PyTorch outputs match.
E0703 11:21:28.382000 19234 site-packages/torch/fx/experimental/recording.py:299] [0/0] failed while running _create_symbolic_sizes_strides_storage_offset(*((64, 64), (64, 1), 0, [False, False], GetItemSource(base=LocalSource(local_name='xs', is_input=True, dynamism=None, is_derefed_cell_contents=False), index=0, index_is_slice=False)), **{'symbolic_context': StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.DYNAMIC: 0>, <DimDynamic.DYNAMIC: 0>], dynamic_strides=[<DimDynamic.INFER_STRIDE: 4>, <DimDynamic.INFER_STRIDE: 4>], constraint_sizes=[None, None], constraint_strides=[None, None], view_base_context=None, tensor_source=GetItemSource(base=LocalSource(local_name='xs', is_input=True, dynamism=None, is_derefed_cell_contents=False), index=0, index_is_slice=False), shape_env_to_source_to_symbol_cache={})})
Traceback (most recent call last):
File "/home/cwtan/nequip-allegro/helion/helion/_compiler/type_propagation.py", line 745, in propagate_call
_CheckForIndexCalls.retry_call(
File "/home/cwtan/nequip-allegro/helion/helion/language/tile_proxy.py", line 166, in retry_call
return fn(*proxy_args, **proxy_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_higher_order_ops/associative_scan.py", line 146, in associative_scan
return torch.compile(associative_scan, fullgraph=True, backend="eager")(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 655, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1432, in __call__
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 598, in __call__
return _compile(
^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1110, in _compile
raise InternalTorchDynamoError(
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1059, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_utils_internal.py", line 97, in wrapper_function
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 761, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 797, in _compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1422, in transform_code_object
transformations(instructions, code_options)
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 257, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 715, in transform
tracer.run()
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3498, in run
super().run()
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1337, in run
while self.step():
^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1246, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 819, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2931, in CALL
self._call(inst)
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2925, in _call
self.call_function(fn, args, kwargs)
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1170, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 414, in call_function
return super().call_function(tx, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 184, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1187, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3724, in inline_call
return tracer.inline_call_()
^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3903, in inline_call_
self.run()
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1337, in run
while self.step():
^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1246, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 819, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2931, in CALL
self._call(inst)
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2925, in _call
self.call_function(fn, args, kwargs)
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1170, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 184, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1187, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3724, in inline_call
return tracer.inline_call_()
^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3903, in inline_call_
self.run()
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1337, in run
while self.step():
^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1246, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 819, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2931, in CALL
self._call(inst)
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2925, in _call
self.call_function(fn, args, kwargs)
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1170, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 184, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1187, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3724, in inline_call
return tracer.inline_call_()
^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3903, in inline_call_
self.run()
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1337, in run
while self.step():
^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1246, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1432, in STORE_FAST
loaded_vt.set_name_hint(name)
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/variables/lazy.py", line 201, in realize_and_forward
return getattr(self.realize(), name)(*args, **kwargs)
^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/variables/lazy.py", line 67, in realize
self._cache.realize()
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/variables/lazy.py", line 33, in realize
self.vt = VariableTracker.build(tx, self.value, source)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/variables/base.py", line 540, in build
return builder.VariableBuilder(tx, source)(value)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 417, in __call__
vt = self._wrap(value)
^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 600, in _wrap
return type_dispatch(self, value)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 1795, in wrap_tensor
example_value = wrap_to_fake_tensor_and_record(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 3032, in wrap_to_fake_tensor_and_record
fake_e = wrap_fake_exception(
^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2641, in wrap_fake_exception
return fn()
^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 3033, in <lambda>
lambda: tx.fake_mode.from_tensor(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2658, in from_tensor
return self.fake_tensor_converter.from_real_tensor(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 392, in from_real_tensor
out = self.meta_converter(
^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_subclasses/meta_utils.py", line 1889, in __call__
r = self.meta_tensor(
^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_subclasses/meta_utils.py", line 1656, in meta_tensor
) = sym_sizes_strides_storage_offset(t, source, symbolic_context)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_subclasses/meta_utils.py", line 930, in sym_sizes_strides_storage_offset
return shape_env._create_symbolic_sizes_strides_storage_offset(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/fx/experimental/recording.py", line 263, in wrapper
return retlog(fn(*args, **kwargs))
^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 4053, in _create_symbolic_sizes_strides_storage_offset
size: list[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 3873, in _produce_dyn_sizes_from_int_tuple
assert all(
^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 3874, in <genexpr>
not is_symbolic(val) for val in tensor_size
^^^^^^^^^^^^^^^^
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 1902, in is_symbolic
return val.node.is_symbolic()
^^^^^^^^
torch._dynamo.exc.InternalTorchDynamoError: AttributeError: 'Integer' object has no attribute 'node'
from user code:
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_higher_order_ops/associative_scan.py", line 150, in associative_scan
leaves, spec = pytree.tree_flatten(xs)
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/utils/_pytree.py", line 1055, in tree_flatten
treespec = helper(tree, leaves)
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/utils/_pytree.py", line 1051, in helper
subspecs = [helper(child, leaves) for child in children]
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/utils/_pytree.py", line 1051, in <listcomp>
subspecs = [helper(child, leaves) for child in children]
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/cwtan/nequip-allegro/test/segred_triton.py", line 150, in <module>
main()
File "/home/cwtan/nequip-allegro/test/segred_triton.py", line 146, in main
test_segmented_reduction()
File "/home/cwtan/nequip-allegro/test/segred_triton.py", line 138, in test_segmented_reduction
helion_output = segmented_reduction_helion(indices, input_data, num_nodes)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/nequip-allegro/helion/helion/runtime/kernel.py", line 232, in __call__
return self.bind(args)(*args)
^^^^^^^^^^^^^^^
File "/home/cwtan/nequip-allegro/helion/helion/runtime/kernel.py", line 129, in bind
bound_kernel = BoundKernel(self, args)
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/nequip-allegro/helion/helion/runtime/kernel.py", line 289, in __init__
self.host_function: HostFunction = HostFunction(
^^^^^^^^^^^^^
File "/home/cwtan/nequip-allegro/helion/helion/_compiler/host_function.py", line 107, in __init__
propagate_types(self, fake_args)
File "/home/cwtan/nequip-allegro/helion/helion/_compiler/type_propagation.py", line 2079, in propagate_types
prop.visit(stmt)
File "/home/cwtan/nequip-allegro/helion/helion/_compiler/type_propagation.py", line 1475, in visit
type_info = visitor(node)
^^^^^^^^^^^^^
File "/home/cwtan/nequip-allegro/helion/helion/_compiler/type_propagation.py", line 1971, in visit_For
body = self._loop_body(node.body)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/nequip-allegro/helion/helion/_compiler/type_propagation.py", line 1935, in _loop_body
self.visit(stmt)
File "/home/cwtan/nequip-allegro/helion/helion/_compiler/type_propagation.py", line 1475, in visit
type_info = visitor(node)
^^^^^^^^^^^^^
File "/home/cwtan/nequip-allegro/helion/helion/_compiler/type_propagation.py", line 1859, in visit_Assign
type_info = self.visit(node.value)
^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/nequip-allegro/helion/helion/_compiler/type_propagation.py", line 1475, in visit
type_info = visitor(node)
^^^^^^^^^^^^^
File "/home/cwtan/nequip-allegro/helion/helion/_compiler/type_propagation.py", line 1801, in visit_Call
return func.propagate_call(tuple(args), kwargs, self.origin())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cwtan/nequip-allegro/helion/helion/_compiler/type_propagation.py", line 779, in propagate_call
raise exc.TorchOpTracingError(e) from e
helion.exc.TorchOpTracingError: InternalTorchDynamoError: AttributeError: 'Integer' object has no attribute 'node'
from user code:
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/_higher_order_ops/associative_scan.py", line 150, in associative_scan
leaves, spec = pytree.tree_flatten(xs)
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/utils/_pytree.py", line 1055, in tree_flatten
treespec = helper(tree, leaves)
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/utils/_pytree.py", line 1051, in helper
subspecs = [helper(child, leaves) for child in children]
File "/home/cwtan/micromamba/envs/allegro-dev/lib/python3.11/site-packages/torch/utils/_pytree.py", line 1051, in <listcomp>
subspecs = [helper(child, leaves) for child in children]
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
While processing:
File "/home/cwtan/nequip-allegro/test/segred_triton.py", line 101, in segmented_reduction_helion
out_vals, _ = torch._higher_order_ops.associative_scan(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^