Skip to content

RuntimeError: Inference tensors cannot be saved for backward (with CUDA kernels installed) #169

@leoauri

Description

@leoauri

I was able to start training with the naive implementation, but after building and installing the CUDA kernels, training blows up like this:

Training:   0%|          | 0/12330 [00:00<?, ?it/s]
Epoch 0:   0%|          | 0/12330 [00:00<?, ?it/s] 
Traceback (most recent call last):
  File "/home/rave-sashimi/.venv/bin/rave", line 10, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/rave-sashimi/RAVE/scripts/main_cli.py", line 30, in main
    app.run(train.main)
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/home/rave-sashimi/RAVE/scripts/train.py", line 279, in main
    trainer.fit(model, train, val, ckpt_path=run)
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 608, in fit
    call._call_and_handle_interrupt(
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 1103, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 1182, in _run_stage
    self._run_train()
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 1205, in _run_train
    self.fit_loop.run()
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py", line 267, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 213, in advance
    batch_output = self.batch_loop.run(kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 90, in advance
    outputs = self.manual_loop.run(kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/manual_loop.py", line 110, in advance
    training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 1485, in _call_strategy_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py", line 378, in training_step
    return self.model.training_step(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rave-sashimi/RAVE/rave/model.py", line 306, in training_step
    y = self.decoder(z)
        ^^^^^^^^^^^^^^^
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rave-sashimi/RAVE/rave/sashimi_decoder.py", line 138, in forward
    x, _ = layer(x)
           ^^^^^^^^
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rave-sashimi/s4/s4/src/models/sequence/backbones/block.py", line 105, in forward
    y_for, new_state = self.layer(y, state=state, **kwargs)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rave-sashimi/s4/s4/src/models/sequence/modules/s4block.py", line 168, in forward
    y, state = self.layer(x, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rave-sashimi/s4/s4/src/models/sequence/kernels/fftconv.py", line 93, in forward
    k, k_state =  self.kernel(L=l_kernel, rate=rate, state=state) # (C H L) (B C H L)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rave-sashimi/s4/s4/src/models/sequence/kernels/ssm.py", line 846, in forward
    r = cauchy_mult(v, z, A)
        ^^^^^^^^^^^^^^^^^^^^
  File "/home/rave-sashimi/s4/s4/extensions/kernels/cauchy.py", line 76, in cauchy_mult
    y = _cauchy_mult(v.view(-1, N), z, w.view(-1, N))
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rave-sashimi/s4/s4/extensions/kernels/cauchy.py", line 59, in _cauchy_mult
    return CauchyMultiplySymmetric.apply(v, z, w)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rave-sashimi/.venv/lib/python3.12/site-packages/torch/autograd/function.py", line 576, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd.

Epoch 0:   0%|          | 0/12330 [00:19<?, ?it/s]

Any ideas?

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions