-
Notifications
You must be signed in to change notification settings - Fork 349
Open
Description
I was able to start training with the naive implementation, but after building and installing the CUDA kernels, training blows up like this:
Training: 0%| | 0/12330 [00:00<?, ?it/s]
Epoch 0: 0%| | 0/12330 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/home/rave-sashimi/.venv/bin/rave", line 10, in <module>
sys.exit(main())
^^^^^^
File "/home/rave-sashimi/RAVE/scripts/main_cli.py", line 30, in main
app.run(train.main)
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/absl/app.py", line 316, in run
_run_main(main, args)
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
sys.exit(main(argv))
^^^^^^^^^^
File "/home/rave-sashimi/RAVE/scripts/train.py", line 279, in main
trainer.fit(model, train, val, ckpt_path=run)
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 608, in fit
call._call_and_handle_interrupt(
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _fit_impl
self._run(model, ckpt_path=self.ckpt_path)
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 1103, in _run
results = self._run_stage()
^^^^^^^^^^^^^^^^^
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 1182, in _run_stage
self._run_train()
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 1205, in _run_train
self.fit_loop.run()
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
self.advance(*args, **kwargs)
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py", line 267, in advance
self._outputs = self.epoch_loop.run(self._data_fetcher)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
self.advance(*args, **kwargs)
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 213, in advance
batch_output = self.batch_loop.run(kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
self.advance(*args, **kwargs)
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 90, in advance
outputs = self.manual_loop.run(kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
self.advance(*args, **kwargs)
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/manual_loop.py", line 110, in advance
training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 1485, in _call_strategy_hook
output = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py", line 378, in training_step
return self.model.training_step(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rave-sashimi/RAVE/rave/model.py", line 306, in training_step
y = self.decoder(z)
^^^^^^^^^^^^^^^
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rave-sashimi/RAVE/rave/sashimi_decoder.py", line 138, in forward
x, _ = layer(x)
^^^^^^^^
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rave-sashimi/s4/s4/src/models/sequence/backbones/block.py", line 105, in forward
y_for, new_state = self.layer(y, state=state, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rave-sashimi/s4/s4/src/models/sequence/modules/s4block.py", line 168, in forward
y, state = self.layer(x, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rave-sashimi/s4/s4/src/models/sequence/kernels/fftconv.py", line 93, in forward
k, k_state = self.kernel(L=l_kernel, rate=rate, state=state) # (C H L) (B C H L)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rave-sashimi/s4/s4/src/models/sequence/kernels/ssm.py", line 846, in forward
r = cauchy_mult(v, z, A)
^^^^^^^^^^^^^^^^^^^^
File "/home/rave-sashimi/s4/s4/extensions/kernels/cauchy.py", line 76, in cauchy_mult
y = _cauchy_mult(v.view(-1, N), z, w.view(-1, N))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rave-sashimi/s4/s4/extensions/kernels/cauchy.py", line 59, in _cauchy_mult
return CauchyMultiplySymmetric.apply(v, z, w)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/torch/autograd/function.py", line 576, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd.
Epoch 0: 0%| | 0/12330 [00:19<?, ?it/s]
Any ideas?
Metadata
Assignees
Labels
No labels