
Fix useless instantiation of zeros #4

@xalelax

This line is problematic: we instantiate zeros and then never use them. Why, you might ask? 🙂

Let's start with the basics: JAX has a peculiar way of keeping track of which arguments are traced in multi-input functions. Have a look yourself: in test_grad, when argnums is 0 or 1, inspect the primals and tangents that reach, say, tesseract_dispatch_jvp_rule. You will see that only one argument is non-static (i.e., len(in_args)==1 and len(tan_args)==1), even though the Tesseract used in the test accepts multiple inputs.
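A minimal sketch of this tracing behaviour, using only public JAX APIs rather than the Tesseract code. The function `f` and the `seen` list are hypothetical, and the `isinstance(..., jax.core.Tracer)` check is an assumed stand-in for whatever the wrapper uses to decide which arguments end up in `in_args`.

```python
import jax
import jax.numpy as jnp

# One entry per Python-level trace of f: (x is a tracer, y is a tracer).
seen = []

def f(x, y):
    # Runs while JAX traces f; a Tracer means JAX is tracking that argument.
    seen.append((isinstance(x, jax.core.Tracer), isinstance(y, jax.core.Tracer)))
    return jnp.sum(x * y)

x = jnp.ones(2, dtype=jnp.float32)
y = jnp.full(2, 3.0, dtype=jnp.float32)

# Eager grad w.r.t. x only: y reaches f as a plain array, not a tracer.
jax.grad(f, argnums=0)(x, y)

# Grad of a jitted f: jit abstracts every argument, so y is traced too.
jax.grad(jax.jit(f), argnums=0)(x, y)

print(seen)
```

The second call mirrors test_grad_jit: once jit has turned every argument into a tracer, a wrapper that filters on tracer-ness can no longer drop the unused input, which is presumably why the JVP rule then sees two primals and two tangents.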

However, when we compute the gradient of a jitted function, as in test_grad_jit, the same inspection shows two primals and two tangents being passed around. Here's what my debugger shows inside tesseract_dispatch_jvp_rule for argnums=2:

>>> tan_args
(Zero(ShapedArray(float32[2])), Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=3/0)>)

That is, we get a tuple of two vectors, and the one we are not differentiating with respect to is marked as a symbolic zero.
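The same symbolic zeros can be observed without a custom primitive. This sketch swaps in `jax.custom_jvp` with `symbolic_zeros=True`, where JAX hands the rule `jax.custom_derivatives.SymbolicZero` objects instead of the internal `Zero` seen in the debugger output. It is an illustration under that assumption, not the Tesseract dispatch rule; `f`, `f_jvp`, and `seen` are hypothetical names.

```python
import jax
import jax.numpy as jnp
from jax.custom_derivatives import SymbolicZero

seen = []  # per call of the rule: (tx is a symbolic zero, ty is a symbolic zero)

@jax.custom_jvp
def f(x, y):
    return jnp.sum(x * y)

def f_jvp(primals, tangents):
    x, y = primals
    tx, ty = tangents
    seen.append((isinstance(tx, SymbolicZero), isinstance(ty, SymbolicZero)))
    out = jnp.sum(x * y)
    # Only real tangents contribute; SymbolicZero objects must not reach jnp ops.
    contribs = []
    if not isinstance(tx, SymbolicZero):
        contribs.append(jnp.sum(tx * y))
    if not isinstance(ty, SymbolicZero):
        contribs.append(jnp.sum(x * ty))
    if not contribs:  # every tangent is a symbolic zero
        return out, jnp.zeros_like(out)
    t_out = contribs[0]
    for c in contribs[1:]:
        t_out = t_out + c
    return out, t_out

f.defjvp(f_jvp, symbolic_zeros=True)

x = jnp.ones(2, dtype=jnp.float32)
y = jnp.full(2, 3.0, dtype=jnp.float32)
g = jax.grad(f, argnums=0)(x, y)              # y's tangent arrives as SymbolicZero
g_jit = jax.grad(jax.jit(f), argnums=0)(x, y)  # still a symbolic zero under jit
print(seen, g, g_jit)
```

Because the rule can tell which tangents are symbolic zeros, it skips the work for the unperturbed input instead of multiplying by an instantiated array of zeros.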

We could use this information when building jvp_inputs in TesseractClient.jacobian_vector_product... but we don't.
Even if Tesseracts supported computing JVPs with respect to only some inputs, the current code forces the client to always ask for derivatives with respect to all of them.
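A sketch of how jvp_inputs could be filtered with this information, assuming the JVP rule has its tangents as a sequence aligned with a list of input names. `select_jvp_inputs` and the names `"a"`/`"b"` are hypothetical, and the real `TesseractClient.jacobian_vector_product` signature is not shown in this issue. The check targets `jax.interpreters.ad.Zero`, the class whose repr appears in the debugger output above.

```python
from typing import Any, Dict, List, Sequence, Tuple

import jax
import jax.numpy as jnp
from jax.interpreters import ad


def select_jvp_inputs(
    names: Sequence[str], tangents: Sequence[Any]
) -> Tuple[List[str], Dict[str, Any]]:
    """Hypothetical helper: keep only inputs whose tangent is not a symbolic zero.

    Returns the names to request derivatives for, plus the matching tangents.
    """
    if len(names) != len(tangents):
        raise ValueError("names and tangents must have the same length")
    jvp_inputs: List[str] = []
    tangent_vector: Dict[str, Any] = {}
    for name, tan in zip(names, tangents):
        if isinstance(tan, ad.Zero):
            # Unperturbed input: no need to request derivatives w.r.t. it.
            continue
        jvp_inputs.append(name)
        tangent_vector[name] = tan
    return jvp_inputs, tangent_vector


# Mimic the debugger output: first tangent is a symbolic zero, second a real array.
zero = ad.Zero(jax.core.ShapedArray((2,), jnp.float32))
jvp_inputs, tangent_vector = select_jvp_inputs(["a", "b"], [zero, jnp.ones(2)])
print(jvp_inputs)
```

With the filtered list, the client would only request derivatives with respect to inputs that are actually being differentiated, and no zeros would need to be instantiated for the others.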

Metadata

Assignees: No one assigned
Labels: enhancement (New feature or request)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests