
TypeError: function takes exactly 27 arguments (20 given) #26

@tai2456

Description


TypeError Traceback (most recent call last)
Cell In[8], line 4
1 for batch in train_loader:
2 x = batch[0].to(device)
----> 4 lls = pc(x, record_cudagraph = True)
5 lls.mean().backward()
6 break

File ~/miniconda3/envs/npc_project/lib/python3.10/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/npc_project/lib/python3.10/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()

File /mnt/c/Users/tkusumoto/pyjuice/src/pyjuice/model/tensorcircuit.py:228, in TensorCircuit.forward(self, inputs, input_layer_fn, cache, return_cache, record_cudagraph, apply_cudagraph, force_use_bf16, force_use_fp32, propagation_alg, _inner_layers_only, _no_buffer_reset, **kwargs)
226 for idx, layer in enumerate(self.input_layer_group):
227 if input_layer_fn is None:
--> 228 layer(inputs, self.node_mars, **kwargs)
230 elif isinstance(input_layer_fn, str):
231 assert hasattr(layer, input_layer_fn), f"Custom input function {input_layer_fn} not found for layer type {type(layer)}."

File ~/miniconda3/envs/npc_project/lib/python3.10/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/npc_project/lib/python3.10/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()

File /mnt/c/Users/tkusumoto/pyjuice/src/pyjuice/layer/input_layer.py:274, in InputLayer.forward(self, data, node_mars, params, missing_mask, _batch_first, _apply_missing_mask_only, **kwargs)
271 grid = (triton.cdiv(layer_num_nodes * batch_size, BLOCK_SIZE),)
273 if not _apply_missing_mask_only:
--> 274 self._mars_kernel[grid](
275 params_ptr = self.params,
276 node_mars_ptr = node_mars,
277 data_ptr = data,
278 vids_ptr = self.vids,
279 s_pids_ptr = self.s_pids,
280 metadata_ptr = self.metadata,
281 s_mids_ptr = self.s_mids,
282 fw_local_ids_ptr = fw_local_ids,
283 layer_num_nodes = layer_num_nodes,
284 batch_size = batch_size,
285 num_vars_per_node = self.num_vars_per_node,
286 nv_block_size = triton.next_power_of_2(self.num_vars_per_node),
287 node_offset = node_offset,
288 BLOCK_SIZE = BLOCK_SIZE,
289 partial_eval = 1 if fw_local_ids is not None else 0,
290 num_warps = 8
291 )
292 else:
293 assert missing_mask is not None, "missing_mask should be provided when _apply_missing_mask_only = True."

File /mnt/c/Users/tkusumoto/pyjuice/src/pyjuice/utils/kernel_launcher.py:84, in FastJITFunction.__getitem__.<locals>.wrapper(*args, **kwargs)
82 kernel(grid0, grid1, grid2)
83 else:
---> 84 kernel(grid0, grid1, grid2)
85 else:
86 if self.device_check:

File ~/miniconda3/envs/npc_project/lib/python3.10/site-packages/triton/compiler/compiler.py:523, in CompiledKernel.__getitem__.<locals>.runner(stream, *args)
521 stream = driver.active.get_current_stream(device)
522 launch_metadata = self.launch_metadata(grid, stream, *args)
--> 523 self.run(grid[0], grid[1], grid[2], stream, self.function, self.packed_metadata, launch_metadata,
524 knobs.runtime.launch_enter_hook, knobs.runtime.launch_exit_hook, *args)

File ~/miniconda3/envs/npc_project/lib/python3.10/site-packages/triton/backends/nvidia/driver.py:708, in CudaLauncher.__call__(self, gridX, gridY, gridZ, stream, function, *args)
706 else:
707 global_scratch = None
--> 708 self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, self.launch_pdl,
709 global_scratch, *args)

TypeError: function takes exactly 27 arguments (20 given)

I am trying to run the official notebook, but I am stuck on this error.
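For reference, here is the cell that fails, reconstructed from the traceback above. The `pc`, `train_loader`, and `device` objects are built earlier in the notebook (circuit construction and data loading); I am only sketching that setup in the comments, not reproducing it here.

```python
import torch

# Setup assumed from the official notebook (not shown in the traceback):
# `pc` is the compiled pyjuice TensorCircuit moved to the GPU, and
# `train_loader` is a torch.utils.data.DataLoader over the training data.
device = torch.device("cuda")

for batch in train_loader:
    x = batch[0].to(device)

    # This forward call is where the TypeError is raised, inside the
    # Triton kernel launch of the first input layer.
    lls = pc(x, record_cudagraph = True)
    lls.mean().backward()
    break
```

The loop breaks after the first batch, exactly as in the notebook, so the error occurs on the very first forward pass.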
