Hello,
When reading the implementation, I noticed that in the forward-backward pass you use a dot product before running the backward pass, specifically in the following line:
GradCache/src/grad_cache/grad_cache.py
Line 241 in 0c33638
surrogate = torch.dot(reps.flatten(), gradient.flatten())
I don't quite understand this. When reading the paper, I imagined you would use the cached gradients directly, something like:
reps.backward(gradient=gradients)
How exactly does the "surrogate" work to utilise the cached gradient, and why wouldn't the "standard" way of doing it work?
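To make the question concrete, here is a minimal toy sketch of how I understand the two variants to relate (my own example, not code from the repo; `encoder` and `cached_grad` are stand-ins for the model and the cached gradient d(loss)/d(reps)):

```python
import torch

torch.manual_seed(0)
encoder = torch.nn.Linear(4, 3)
x = torch.randn(2, 4)

# Stand-in for the gradient d(loss)/d(reps) cached from the earlier pass.
cached_grad = torch.randn(2, 3)

# Variant A: the "standard" way I had in mind.
reps = encoder(x)
reps.backward(gradient=cached_grad)
grads_a = [p.grad.clone() for p in encoder.parameters()]
encoder.zero_grad()

# Variant B: the dot-product surrogate, as in grad_cache.py line 241.
# Since cached_grad carries no graph, d(surrogate)/d(reps) == cached_grad,
# so backprop through the scalar surrogate should feed the same gradient
# into reps and hence into the encoder parameters.
reps = encoder(x)
surrogate = torch.dot(reps.flatten(), cached_grad.flatten())
surrogate.backward()
grads_b = [p.grad.clone() for p in encoder.parameters()]

print(all(torch.allclose(a, b) for a, b in zip(grads_a, grads_b)))  # True
```

If both variants really do give the same parameter gradients, I'm mainly wondering why the surrogate form was chosen over calling `backward(gradient=...)` directly.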
Thanks.