Summary
The current `apply_tesseract` evaluates the `apply` endpoint at the given inputs after running some sanity checks. In some cases it would be more convenient to return the function itself, without specifying the inputs. This also means the checks would not need to be repeated on every call. For example, something like this:
```python
from collections.abc import Callable

import jax
from tesseract_core import Tesseract

# Jaxeract, _is_static, split_args, _check_dtype and tesseract_dispatch_p are
# internals of tesseract-jax; their import paths are omitted here.


def gen_traceable_apply(tesseract_client: Tesseract) -> Callable:
    # Validate the Tesseract once, when the function is generated
    if not isinstance(tesseract_client, Tesseract):
        raise TypeError(
            "The first argument must be a Tesseract object. "
            f"Got {type(tesseract_client)} instead."
        )
    if "abstract_eval" not in tesseract_client.available_endpoints:
        raise ValueError(
            "Given Tesseract object does not support abstract_eval, "
            "which is required for compatibility with JAX."
        )
    client = Jaxeract(tesseract_client)

    def wrapped_apply(inputs):
        # Split the flattened inputs into traced arrays and static values
        flat_args, input_pytreedef = jax.tree.flatten(inputs)
        is_static_mask = tuple(_is_static(arg) for arg in flat_args)
        array_args, static_args = split_args(flat_args, is_static_mask)

        # Get abstract values for outputs, so we can unflatten them later
        output_pytreedef, avals = None, None
        avals = client.abstract_eval(
            array_args,
            static_args,
            input_pytreedef,
            output_pytreedef,
            avals,
            is_static_mask,
        )
        is_aval = lambda x: isinstance(x, dict) and "dtype" in x and "shape" in x
        flat_avals, output_pytreedef = jax.tree.flatten(avals, is_leaf=is_aval)
        for aval in flat_avals:
            if not is_aval(aval):
                continue
            _check_dtype(aval["dtype"])
        flat_avals = tuple(
            jax.ShapeDtypeStruct(shape=tuple(aval["shape"]), dtype=aval["dtype"])
            for aval in flat_avals
        )

        # Apply the primitive
        out = tesseract_dispatch_p.bind(
            *array_args,
            static_args=static_args,
            input_pytreedef=input_pytreedef,
            output_pytreedef=output_pytreedef,
            output_avals=flat_avals,
            is_static_mask=is_static_mask,
            client=client,
            eval_func="apply",
        )

        # Unflatten the output
        return jax.tree.unflatten(output_pytreedef, out)

    return wrapped_apply
```

Why is this needed?
To compute the gradient of `apply_tesseract`, one currently has to re-wrap it in a function of the inputs:
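```python
import equinox as eqx
from tesseract_jax import apply_tesseract

# `t` is a Tesseract object and `inputs` a PyTree of inputs.
# This is not possible, as the output of `apply_tesseract` is just a PyTree:
eqx.filter_grad(apply_tesseract(t, inputs))
# Instead, we have to re-wrap it in a function of the inputs:
eqx.filter_grad(lambda inputs: apply_tesseract(t, inputs))
```

Usage example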
I would prefer the following workflow:
```python
eqx.filter_grad(gen_traceable_apply(t))
```
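A minimal, self-contained sketch of the proposed pattern, assuming a toy setup: `ToyClient` is a hypothetical stand-in for a Tesseract (it is not the tesseract-jax or tesseract-core API), and `make_traceable_apply` is a hypothetical simplification of `gen_traceable_apply` that skips the abstract-eval and primitive machinery. The factory validates the client once and returns a plain function of `inputs`, which `jax.grad` accepts directly; passing the already-evaluated output instead is rejected. Since `jax.grad` (like `eqx.filter_grad`) requires a scalar output, the toy `apply` returns `sum(x**2)`.

```python
import jax
import jax.numpy as jnp


class ToyClient:
    """Hypothetical stand-in for a Tesseract client (not the real API)."""

    available_endpoints = ("apply", "abstract_eval")

    def __init__(self):
        self.validations = 0  # how many times a factory validated this client

    def apply(self, inputs):
        # Toy differentiable "apply" endpoint with a scalar output: sum(x**2)
        return jnp.sum(inputs["x"] ** 2)


def make_traceable_apply(client):
    """Hypothetical factory: validate once, return a function of `inputs`."""
    if "abstract_eval" not in client.available_endpoints:
        raise ValueError("client must support abstract_eval")
    client.validations += 1  # runs when the function is created, not per call

    def wrapped_apply(inputs):
        return client.apply(inputs)

    return wrapped_apply


client = ToyClient()
apply_fn = make_traceable_apply(client)
inputs = {"x": jnp.array([1.0, 2.0])}

# Passing the evaluated output fails: jax.grad needs a callable
try:
    jax.grad(apply_fn(inputs))
    rejected = False
except TypeError:
    rejected = True

# Re-wrapping in a lambda works, but so does passing the function directly
grads_lambda = jax.grad(lambda inputs: apply_fn(inputs))(inputs)
grads = jax.grad(apply_fn)(inputs)

print(grads["x"])  # gradient of sum(x**2) is 2 * x
```

Because validation lives in the factory, `client.validations` stays at 1 no matter how often `apply_fn` or its gradient is evaluated.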