Using a custom PyTorch backward pass with Lightning #18699
thibmonsel started this conversation in General.
Hello,
I'm currently using a custom `backward` with `torch.autograd.Function`, and I can't find anything that shows how to integrate it with Lightning. As an example of a custom backward pass, think of the one from Neural ODEs (Algorithm 1). What I have is something along these lines:
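A minimal stand-in for such a Function, assuming a hypothetical op `y = tanh(x * w)` (much simpler than the Neural ODE adjoint from Algorithm 1). The name `ScaledTanh` is made up for illustration; `torch.autograd.gradcheck` compares the hand-written `backward` against finite differences.

```python
import torch


class ScaledTanh(torch.autograd.Function):
    """Hypothetical custom op: y = tanh(x * w).

    x has shape (batch, d) and w has shape (d,); the gradient is written
    by hand instead of being derived by autograd.
    """

    @staticmethod
    def forward(ctx, x, w):
        y = torch.tanh(x * w)
        # Save exactly what the hand-written backward needs.
        ctx.save_for_backward(x, w, y)
        return y

    @staticmethod
    def backward(ctx, grad_out):
        x, w, y = ctx.saved_tensors
        # d tanh(u)/du = 1 - tanh(u)^2, with u = x * w
        grad_u = grad_out * (1 - y * y)
        grad_x = grad_u * w
        # w was broadcast over the batch, so sum its gradient over dim 0.
        grad_w = (grad_u * x).sum(dim=0)
        return grad_x, grad_w


# Check the custom gradient numerically (gradcheck expects double precision).
x = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
w = torch.randn(3, dtype=torch.double, requires_grad=True)
ok = torch.autograd.gradcheck(ScaledTanh.apply, (x, w))
```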
For a regular PyTorch training loop, I would call it something like this:
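A minimal sketch of such a plain PyTorch loop, assuming a hypothetical op `ScaledTanh` computing `y = tanh(x * w)` (defined inline so the snippet runs on its own) and a toy regression target. The Function is invoked through `.apply`, and `loss.backward()` is what triggers the hand-written `backward`.

```python
import torch
import torch.nn.functional as F


class ScaledTanh(torch.autograd.Function):
    """Hypothetical custom op y = tanh(x * w), x: (batch, d), w: (d,)."""

    @staticmethod
    def forward(ctx, x, w):
        y = torch.tanh(x * w)
        ctx.save_for_backward(x, w, y)
        return y

    @staticmethod
    def backward(ctx, grad_out):
        x, w, y = ctx.saved_tensors
        grad_u = grad_out * (1 - y * y)
        return grad_u * w, (grad_u * x).sum(dim=0)


torch.manual_seed(0)
x = torch.randn(32, 3)
target = torch.tanh(x * torch.tensor([0.5, -1.0, 2.0]))  # toy target

w = torch.nn.Parameter(torch.ones(3))
optimizer = torch.optim.SGD([w], lr=0.1)

losses = []
for step in range(100):
    optimizer.zero_grad()
    # Custom Functions are called via .apply, never instantiated directly.
    pred = ScaledTanh.apply(x, w)
    loss = F.mse_loss(pred, target)
    loss.backward()  # autograd dispatches to ScaledTanh.backward here
    optimizer.step()
    losses.append(loss.item())
```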
How would I integrate this into a `LightningModule`? As far as I can tell, the `training_step` and `backward` methods won't do the trick, but maybe I'm overlooking some other hooks! Another possibility would be to follow the Legendre example.
Thanks in advance!