
Role of dot product operation in forward-backward pass #29


Description

@ahmed-tabib

Hello,
When reading the implementation, I noticed that in the forward-backward pass you compute a dot product before running the backward pass, specifically on the following line:

surrogate = torch.dot(reps.flatten(), gradient.flatten())

I can't understand this. When reading the paper, I imagined you would use the cached gradients directly, something like:

reps.backward(gradient=gradients)

How exactly does the "surrogate" use the cached gradient, and why wouldn't the "standard" way of doing it work?
Thanks.
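
For reference, here is a minimal sketch (not the repository's code; the toy linear model, shapes, and variable names are made up for illustration) showing that the dot-product surrogate and the direct vector-Jacobian call produce the same parameter gradients. The reason is that the gradient of torch.dot(reps.flatten(), gradient.flatten()) with respect to reps is exactly the cached gradient tensor, since the cached tensor does not require grad:

import torch

model = torch.nn.Linear(4, 3)   # hypothetical toy encoder
x = torch.randn(2, 4)

# stand-in for the gradient cached from the first, representation-only pass
gradient = torch.randn(2, 3)

# Variant 1: scalar surrogate, then a plain backward()
reps = model(x)
surrogate = torch.dot(reps.flatten(), gradient.flatten())
surrogate.backward()
grads_surrogate = [p.grad.clone() for p in model.parameters()]

model.zero_grad()

# Variant 2: pass the cached gradient directly as the backward vector
reps = model(x)
reps.backward(gradient=gradient)
grads_direct = [p.grad.clone() for p in model.parameters()]

# Both routes yield identical parameter gradients
for a, b in zip(grads_surrogate, grads_direct):
    assert torch.allclose(a, b)

Mathematically both compute the same vector-Jacobian product; the surrogate form simply packages it as an ordinary scalar loss that can be passed through the usual loss.backward() machinery.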
