Description
TypeError Traceback (most recent call last)
Cell In[8], line 4
1 for batch in train_loader:
2 x = batch[0].to(device)
----> 4 lls = pc(x, record_cudagraph = True)
5 lls.mean().backward()
6 break
File ~/miniconda3/envs/npc_project/lib/python3.10/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/npc_project/lib/python3.10/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()
File /mnt/c/Users/tkusumoto/pyjuice/src/pyjuice/model/tensorcircuit.py:228, in TensorCircuit.forward(self, inputs, input_layer_fn, cache, return_cache, record_cudagraph, apply_cudagraph, force_use_bf16, force_use_fp32, propagation_alg, _inner_layers_only, _no_buffer_reset, **kwargs)
226 for idx, layer in enumerate(self.input_layer_group):
227 if input_layer_fn is None:
--> 228 layer(inputs, self.node_mars, **kwargs)
230 elif isinstance(input_layer_fn, str):
    231         assert hasattr(layer, input_layer_fn), f"Custom input function {input_layer_fn} not found for layer type {type(layer)}."
File ~/miniconda3/envs/npc_project/lib/python3.10/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/npc_project/lib/python3.10/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()
File /mnt/c/Users/tkusumoto/pyjuice/src/pyjuice/layer/input_layer.py:274, in InputLayer.forward(self, data, node_mars, params, missing_mask, _batch_first, _apply_missing_mask_only, **kwargs)
271 grid = (triton.cdiv(layer_num_nodes * batch_size, BLOCK_SIZE),)
273 if not _apply_missing_mask_only:
--> 274 self._mars_kernel[grid](
275 params_ptr = self.params,
276 node_mars_ptr = node_mars,
277 data_ptr = data,
278 vids_ptr = self.vids,
279 s_pids_ptr = self.s_pids,
280 metadata_ptr = self.metadata,
281 s_mids_ptr = self.s_mids,
282 fw_local_ids_ptr = fw_local_ids,
283 layer_num_nodes = layer_num_nodes,
284 batch_size = batch_size,
285 num_vars_per_node = self.num_vars_per_node,
286 nv_block_size = triton.next_power_of_2(self.num_vars_per_node),
287 node_offset = node_offset,
288 BLOCK_SIZE = BLOCK_SIZE,
289 partial_eval = 1 if fw_local_ids is not None else 0,
290 num_warps = 8
291 )
292 else:
    293     assert missing_mask is not None, "missing_mask should be provided when _apply_missing_mask_only = True."
File /mnt/c/Users/tkusumoto/pyjuice/src/pyjuice/utils/kernel_launcher.py:84, in FastJITFunction.__getitem__.<locals>.wrapper(*args, **kwargs)
82 kernel(grid0, grid1, grid2)
83 else:
---> 84 kernel(grid0, grid1, grid2)
85 else:
86 if self.device_check:
File ~/miniconda3/envs/npc_project/lib/python3.10/site-packages/triton/compiler/compiler.py:523, in CompiledKernel.__getitem__.<locals>.runner(stream, *args)
521 stream = driver.active.get_current_stream(device)
522 launch_metadata = self.launch_metadata(grid, stream, *args)
--> 523 self.run(grid[0], grid[1], grid[2], stream, self.function, self.packed_metadata, launch_metadata,
524 knobs.runtime.launch_enter_hook, knobs.runtime.launch_exit_hook, *args)
File ~/miniconda3/envs/npc_project/lib/python3.10/site-packages/triton/backends/nvidia/driver.py:708, in CudaLauncher.__call__(self, gridX, gridY, gridZ, stream, function, *args)
706 else:
707 global_scratch = None
--> 708 self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, self.launch_pdl,
709 global_scratch, *args)
TypeError: function takes exactly 27 arguments (20 given)
I am trying to run the official notebook, but I am stuck on this error.
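
The final `TypeError: function takes exactly 27 arguments (20 given)` is raised inside Triton's CUDA launcher, which may indicate a mismatch between the installed Triton version and the launcher signature that pyjuice's `FastJITFunction` wrapper was written against. Below is a minimal sketch (not part of the original report) for collecting the environment information that would help narrow this down; the package names `torch`, `triton`, and `pyjuice` are assumptions based on the paths in the traceback.

```python
# Minimal sketch: print the package versions and CUDA setup relevant to this
# traceback. Package names are assumptions; adjust if your environment differs.
import importlib.metadata as md

import torch

for pkg in ("torch", "triton", "pyjuice"):
    try:
        print(f"{pkg}: {md.version(pkg)}")
    except md.PackageNotFoundError:
        # e.g. pyjuice installed from a source checkout without metadata
        print(f"{pkg}: not installed (or no package metadata)")

# Device info, since the failure happens during a CUDA kernel launch.
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))
    print("CUDA runtime:", torch.version.cuda)
```

If Triton or PyTorch was upgraded independently of pyjuice, pinning them to the versions pyjuice was tested against may resolve the argument-count mismatch.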