Description
This line is problematic. We are instantiating zeros and then don't do anything with them. Why, you might ask 🙂
Let's start with the basics: JAX has a funny way of keeping track of which arguments are traced in multi-input functions. Have a look yourself: in test_grad, when argnums is either 0 or 1, if you inspect which primals and tangents are sent to, say, the tesseract_dispatch_jvp_rule function, you will see that only one argument is non-static (i.e., len(in_args)==1 and len(tan_args)==1), even though the Tesseract we are using in the test accepts multiple inputs.
However, when the gradient of a jitted function is calculated, like we do in test_grad_jit, even if we try to do the same thing as above we can see that we are passing around two primals and two tangents. Here's what my debugger shows inside tesseract_dispatch_jvp_rule for argnums=2:
>>> tan_args
(Zero(ShapedArray(float32[2])), Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=3/0)>)
i.e., we have a tuple of two vectors, one of which (the one we are not differentiating with respect to) is marked as a symbolic zero.
We could make use of this info when we build jvp_inputs in the TesseractClient.jacobian_vector_product... but we don't.
Even if Tesseracts supported computing the JVP with respect to only some inputs, this forces the client to always ask for derivatives with respect to all of them.
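To illustrate the idea, here is a minimal sketch of a JVP rule that inspects symbolic zeros and only touches the inputs that actually take part in differentiation. It is not the project's code: it uses JAX's public `jax.custom_jvp` with `symbolic_zeros=True` as a stand-in for the primitive-level rule, and `multiply`, `multiply_jvp`, and `seen_active` are hypothetical names made up for this example. In `tesseract_dispatch_jvp_rule` the analogous check would presumably be against the `Zero` type shown in the debugger output above.

```python
import jax
import jax.numpy as jnp
from jax.custom_derivatives import SymbolicZero

# Hypothetical bookkeeping: which input positions carried a real tangent
# each time the JVP rule was invoked.
seen_active = []


@jax.custom_jvp
def multiply(x, y):
    return x * y


def multiply_jvp(primals, tangents):
    x, y = primals
    # Inputs that are not differentiated arrive as SymbolicZero objects.
    active = [i for i, t in enumerate(tangents) if not isinstance(t, SymbolicZero)]
    seen_active.append(active)

    out = x * y
    partials = (y, x)  # d(x*y)/dx = y, d(x*y)/dy = x
    tangent_out = jnp.zeros_like(out)
    # Only accumulate contributions from the active inputs; the symbolic
    # zeros are never instantiated.
    for i in active:
        tangent_out = tangent_out + partials[i] * tangents[i]
    return out, tangent_out


# symbolic_zeros=True asks JAX to pass SymbolicZero instead of dense zeros.
multiply.defjvp(multiply_jvp, symbolic_zeros=True)


def loss(x, y):
    return multiply(x, y).sum()


x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 4.0])

# Gradient of a jitted function, as in test_grad_jit.
grad_x = jax.grad(jax.jit(loss), argnums=0)(x, y)
grad_y = jax.grad(jax.jit(loss), argnums=1)(x, y)
```

In this sketch the list `active` plays the role that a filtered `jvp_inputs` could play in `TesseractClient.jacobian_vector_product`: it names only the inputs whose tangents are not symbolic zeros, so a backend that supports partial JVPs would only be asked for those.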