Description
NB: This is more of a record of what seem to be inherent limitations of encapsulating expensive nonlinear black-box functions in jax primitives. I don't currently see any practical way to address the problem below in a general way given the current design of jax. Nevertheless, I think it is worth keeping this issue open for awareness and tracking.
User story
I really like that tesseract-jax means I no longer have to re-write and re-build my Tesseract every time I want to (nonlinearly) change my loss function and can instead have my loss function implemented as a jax function in my Python workflow script. However, when doing so I incur a performance hit.
Roots of the problem
- Computing the gradient of any nonlinear loss function `L` applied to the outputs of a Tesseract `T(x)` requires invocations of both the `vector_jacobian_product` endpoint (to pull the loss cotangent back through `T'(x)`) and the `apply` endpoint (to calculate the "primal" output needed for `L'(T(x))`); see the sketch after this list. (For a linear loss function, the `apply` call is not required and can be optimised away.)
- If the Tesseract `apply` function itself includes any nonlinear operations, calculating the `vector_jacobian_product` would (usually) require repeating the same primal operations involved in `apply` up to the last nonlinear operation, storing residuals on the forward pass for re-use on the reverse pass (these cannot be compiled away, as they depend on the values of the primals and not just their abstract form). Therefore, the outputs of a combined `apply_and_vjp` could be calculated more efficiently than two separate serial calls to `apply` and `vector_jacobian_product`, provided the cotangent vectors are already known.
- However, such a function could not be used in a general reverse-mode pipeline, as the primal output is typically required before the values of the cotangent vectors are known. (Nevertheless, if its internals are fully exposed under `jit`, similar efficiency savings are still possible within memory constraints, as a single forward pass will be performed.)
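To make the serial dependence concrete, here is a minimal sketch with stand-in functions (these are not the tesseract-jax API; `black_box` plays the role of a Tesseract's `apply` endpoint and `black_box_vjp` its `vector_jacobian_product`):

```python
import jax
import jax.numpy as jnp

def black_box(x):
    # Stand-in for an expensive nonlinear Tesseract `apply` endpoint.
    return jnp.tanh(x)

def black_box_vjp(x, ct):
    # Stand-in for the `vector_jacobian_product` endpoint; note the primal
    # operations are repeated internally before the cotangent is pulled back.
    _, vjp_fn = jax.vjp(black_box, x)
    return vjp_fn(ct)[0]

def grad_loss(x):
    y = black_box(x)                              # endpoint call 1: `apply`
    ct = jax.grad(lambda y: jnp.sum(y ** 2))(y)   # cotangent needs the primal
    return black_box_vjp(x, ct)                   # endpoint call 2 repeats the forward work
```

Because the cotangent `ct` depends on the primal output `y`, the two endpoint calls are inherently serial, which is why a fused `apply_and_vjp` cannot be slotted into a standard reverse-mode pipeline.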
Demonstration
(I have a notebook for the below but github won't let me upload.)
I've made an example with diffrax using 2^20 timesteps of the dimensionless plasma diffusion equation (where the thermal conductivity scales as Te ** 2.5).
```python
import equinox as eqx
import jax
import jax.numpy as jnp
import diffrax

jax.config.update("jax_enable_x64", True)

class Stepper(diffrax.Euler):
    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        # The vector field itself performs a full implicit update, so the
        # "Euler" step reduces to a single call to it.
        args["dt"] = t1 - t0
        y1 = terms.vf(t0, y0, args)
        dense_info = dict(y0=y0, y1=y1)
        return y1, None, dense_info, None, diffrax.RESULTS.successful

@eqx.filter_jit
def conduct_plasma(t, y, args):
    # Semi-implicit conduction step: the conductivity (~ Te ** 2.5) is
    # evaluated at the current state, and the diffusion operator is
    # inverted with a tridiagonal solve.
    k_th = args["dt"] * y ** 2.5 / 3.5 / args["dx"] ** 2
    d = 1.0 + 2.0 * k_th
    dl = jnp.zeros_like(d).at[1:].set(-k_th[:-1])
    d = d.at[0].add(dl[1])  # zero-flux boundary condition on the left
    du = jnp.zeros_like(d).at[:-1].set(-k_th[1:])
    d = d.at[-1].add(du[-2])  # zero-flux boundary condition on the right
    rhs = y[:, None]
    y_new = jax.lax.linalg.tridiagonal_solve(dl, d, du, rhs)
    return y_new[:, 0]

t1 = 0.1
nt = 2 ** 20
dt = t1 / nt
nx = 1000
dx = 2.0 / nx
Te = 2. + jnp.tanh(5. * jnp.arange(-1 + dx / 2, 1, dx))

@eqx.filter_jit
def apply_wrapped(Te_init):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(conduct_plasma),
        Stepper(),
        t0=0.0, t1=t1, dt0=dt,
        max_steps=nt, y0=Te_init,
        saveat=diffrax.SaveAt(t0=False, t1=True),
        args={"dx": dx, "dt": dt},
    )
    return sol.ys[-1]
```

Timings (after ensuring jit)
- 0m24s Calculating the jitted `apply_wrapped` function
- 2m05s Calculating the jitted vjp only (as in `vjp_jit` in the jax recipe)
- 2m04s Calculating apply and the jitted vjp (retaining both values)
- 2m08s Calculating the gradient of the sum of squares of the final temperature
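For reference, a minimal sketch of how the three reverse-mode quantities above can be constructed with the standard `jax.vjp` pattern (an assumption about the recipe, which is not reproduced here):

```python
# Jitted vjp only: pull a known cotangent ct back through the simulation.
vjp_jit = jax.jit(lambda x, ct: jax.vjp(apply_wrapped, x)[1](ct)[0])

# Apply and jitted vjp, retaining both values within one traced function.
@jax.jit
def apply_and_vjp(x, ct):
    y, vjp_fn = jax.vjp(apply_wrapped, x)
    return y, vjp_fn(ct)[0]

# Gradient of the sum of squares of the final temperature.
grad_fn = jax.jit(jax.grad(lambda x: jnp.sum(apply_wrapped(x) ** 2)))
```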
We observe that all three reverse-mode functions have very similar runtimes, differing by much less than the runtime of the forward pass. This is because the vjp already requires (checkpointed) recomputation of the entire simulation except for the final timestep, and the extra cost of completing that final timestep is negligible.
If we tried to calculate the final bullet point using tesseract-jax, it would take ~2m28s (roughly the 0m24s `apply` time plus the 2m05s vjp time) due to the separate, serial `apply` and `vector_jacobian_product` calls.
Implications
Note that for this example, the `apply` call is ~20% of the cost of the `vector_jacobian_product` call. However, since a vjp typically costs around 3× the corresponding forward pass, the relative cost can be as high as ~33% (perhaps more in some cases), meaning that the inability to utilise a jitted `apply_and_vjp` endpoint can lead to performance hits as high as ~33%.
Partial/imperfect solutions
One saving grace here is that setting `num_workers` to 2 or higher should hopefully enable concurrent calls to the two endpoints, meaning the performance hit would only be in terms of total compute time rather than wall-clock time.
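A hedged sketch of what such overlapped client-side calls might look like, applicable only when the cotangent is already known (the `tess` handle and the endpoint signatures are assumptions for illustration, not the actual tesseract client API):

```python
from concurrent.futures import ThreadPoolExecutor

def apply_and_vjp_concurrent(tess, inputs, ct):
    # With num_workers >= 2 serving the Tesseract, the two requests can be
    # handled in parallel, so wall-clock time approaches max(apply, vjp)
    # rather than their sum. Not usable when ct depends on the primal.
    with ThreadPoolExecutor(max_workers=2) as pool:
        fut_apply = pool.submit(tess.apply, inputs)
        fut_vjp = pool.submit(tess.vector_jacobian_product, inputs, ct)
        return fut_apply.result(), fut_vjp.result()
```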
The only real "solution" to this problem seems to be the one presented here, where a `custom_vjp` is defined using two ffi calls, the first of which computes the forward pass and returns all required residuals (see the sketch below). Going down this route is probably not worth exploring at this juncture, as the memory/bandwidth requirements for complex functions would likely be very high (unless there is a way to checkpoint), but hypothetically the forward pass could simply be an overload of `apply` instead of an additional endpoint (perhaps even returning residuals by default if the added expense is minimal).
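A minimal sketch of that residual-based pattern (the `apply_with_residuals` and `vjp_from_residuals` endpoints are hypothetical, and `t` is the Tesseract handle):

```python
import jax
from tesseract_jax import apply_tesseract

@jax.custom_vjp
def apply_tesseract_res(inputs: dict):
    return apply_tesseract(t, inputs)

def fwd(inputs):
    # Hypothetical endpoint: one forward pass returning the outputs *and*
    # whatever residuals the reverse pass needs.
    outputs, residuals = apply_with_residuals(t, inputs)
    return outputs, residuals

def bwd(residuals, ct):
    # Hypothetical endpoint: reverse pass consuming the stored residuals
    # instead of re-running the primal computation.
    return (vjp_from_residuals(t, residuals, ct),)

apply_tesseract_res.defvjp(fwd, bwd)
```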
One option, which would only have a chance of being somewhat efficient for low-dimensional outputs (and would be very computationally expensive for high-dimensional outputs), is an `apply_and_jacrev` endpoint:
```python
@jax.custom_vjp
def apply_tesseract_rev(inputs: dict):
    return apply_tesseract(t, inputs)

def fwd(inputs):
    # Hypothetical endpoint: the primal output and the full reverse-mode
    # Jacobian are computed together in a single forward pass.
    apply, jacrev = apply_and_jacrev_jitted(t, inputs)
    return apply, jacrev  # the Jacobian is stored as the residual

def bwd(jacrev, ct):
    # Contract the output cotangent with the stored Jacobian (shown for flat
    # array inputs/outputs; dict-valued ones would need a tree-map instead).
    return (jnp.dot(ct, jacrev),)

apply_tesseract_rev.defvjp(fwd, bwd)
```
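Note the trade-off: assembling a full `jacrev`-style Jacobian costs roughly one vjp per output element, so an `apply_and_jacrev` endpoint only pays off when the output dimension is small; for the 1000-point temperature profile in the example above it would be far more expensive than a single vjp.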