
Offer the option of returning a traceable apply function rather than just its outputs #19

@jpbrodrick89

Description


Summary

The current apply_tesseract evaluates the apply endpoint at the specified inputs after doing some sanity checks. In some cases it may be more convenient to return the function itself, without specifying the inputs. This would also mean the checks do not need to be repeated on every call. For example, something like this:

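from collections.abc import Callable

import jax
from tesseract_core import Tesseract

# Assumes this lives in the same module as apply_tesseract, where Jaxeract,
# _is_static, split_args, _check_dtype and tesseract_dispatch_p are defined.


def gen_traceable_apply(
    tesseract_client: Tesseract
) -> Callable:
    """Validate `tesseract_client` once and return a JAX-traceable apply function."""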
    if not isinstance(tesseract_client, Tesseract):
        raise TypeError(
            "The first argument must be a Tesseract object. "
            f"Got {type(tesseract_client)} instead."
        )

    if "abstract_eval" not in tesseract_client.available_endpoints:
        raise ValueError(
            "Given Tesseract object does not support abstract_eval, "
            "which is required for compatibility with JAX."
        )

    client = Jaxeract(tesseract_client)

    def wrapped_apply(inputs):
        flat_args, input_pytreedef = jax.tree.flatten(inputs)
        is_static_mask = tuple(_is_static(arg) for arg in flat_args)
        array_args, static_args = split_args(flat_args, is_static_mask)
    
        # Get abstract values for outputs, so we can unflatten them later
        output_pytreedef, avals = None, None
        avals = client.abstract_eval(
            array_args,
            static_args,
            input_pytreedef,
            output_pytreedef,
            avals,
            is_static_mask,
        )
    
        is_aval = lambda x: isinstance(x, dict) and "dtype" in x and "shape" in x
        flat_avals, output_pytreedef = jax.tree.flatten(avals, is_leaf=is_aval)
        for aval in flat_avals:
            if not is_aval(aval):
                continue
            _check_dtype(aval["dtype"])
    
        flat_avals = tuple(
            jax.ShapeDtypeStruct(shape=tuple(aval["shape"]), dtype=aval["dtype"])
            for aval in flat_avals
        )
    
        # Apply the primitive
        out = tesseract_dispatch_p.bind(
            *array_args,
            static_args=static_args,
            input_pytreedef=input_pytreedef,
            output_pytreedef=output_pytreedef,
            output_avals=flat_avals,
            is_static_mask=is_static_mask,
            client=client,
            eval_func="apply",
        )
    
        # Unflatten the output
        return jax.tree.unflatten(output_pytreedef, out)

    return wrapped_apply
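The core of the proposal is a closure pattern: validate the client once when the function is built, then return a function that only does per-call (tracing-time) work. Below is a minimal, dependency-free sketch of that pattern. `ToyClient`, `make_traceable_apply` and the `validation_calls` counter are hypothetical stand-ins, not part of Tesseract or tesseract_jax, and the toy `apply` is plain JAX so that the closure stays traceable.

```python
from collections.abc import Callable

import jax
import jax.numpy as jnp

# Records how many times validation runs (illustration only).
validation_calls = []


class ToyClient:
    """Hypothetical stand-in for a Tesseract client; not the real API."""

    available_endpoints = ("apply", "abstract_eval")

    def apply(self, inputs):
        # Pure-JAX "endpoint", so the returned closure is traceable.
        return {"y": jnp.sin(inputs["x"]) * inputs["scale"]}


def make_traceable_apply(client) -> Callable:
    # Sanity checks run once, when the function is built.
    validation_calls.append(client)
    if "abstract_eval" not in client.available_endpoints:
        raise ValueError("client must support abstract_eval")

    def wrapped_apply(inputs):
        # Per-call work only; the checks above are not repeated.
        return client.apply(inputs)

    return wrapped_apply


apply_fn = jax.jit(make_traceable_apply(ToyClient()))
inputs = {"x": jnp.array([0.0, 1.0]), "scale": jnp.array(2.0)}
for _ in range(3):
    out = apply_fn(inputs)

print(out["y"], len(validation_calls))
```

Calling `apply_fn` three times leaves `validation_calls` with a single entry, which is the saving the issue describes.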

Why is this needed?

If one wants to calculate the gradient of apply_tesseract, one has to re-wrap it in a function of the inputs:

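import equinox as eqx
from tesseract_jax import apply_tesseract

# `t` is a Tesseract object and `inputs` is its input PyTree.
# This is not possible, as the output of `apply_tesseract` is just a PyTree, not a function:
eqx.filter_grad(apply_tesseract(t, inputs))

# Instead we have to re-wrap it:
eqx.filter_grad(lambda inputs: apply_tesseract(t, inputs))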
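Note that `jax.grad` (and `eqx.filter_grad`, which builds on it) differentiates a function, and only one with a scalar output, so in practice the re-wrapped lambda usually also reduces the output PyTree to a scalar. The sketch below uses plain `jax.grad` with a hypothetical `toy_apply` standing in for `apply_tesseract(t, inputs)`; the sum-of-outputs loss is an arbitrary choice for illustration.

```python
import jax
import jax.numpy as jnp


def toy_apply(inputs):
    """Hypothetical stand-in for apply_tesseract(t, inputs): returns a PyTree."""
    return {"y": inputs["x"] ** 2}


inputs = {"x": jnp.array([1.0, 2.0, 3.0])}

# Calling the function first only yields a PyTree of arrays, not something to differentiate.
outputs = toy_apply(inputs)


# Re-wrapping: build a scalar loss as a function of the inputs, then differentiate it.
def loss(inp):
    return jnp.sum(toy_apply(inp)["y"])


grads = jax.grad(loss)(inputs)  # same PyTree structure as `inputs`
print(grads["x"])
```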

Usage example

I would prefer the following workflow:

eqx.filter_grad(gen_traceable_apply(t))
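Because the proposed function is just a callable over the input PyTree, it could be handed to other JAX transformations as well, not only gradients. The sketch below uses a hypothetical `gen_toy_apply` factory in place of `gen_traceable_apply(t)`. Whether each transformation works for a real Tesseract would depend on which endpoints it implements and which rules the underlying primitive registers, which this sketch does not cover. `jax.jacrev` is shown because, unlike `jax.grad` or `eqx.filter_grad`, it accepts non-scalar PyTree outputs.

```python
import jax
import jax.numpy as jnp


def gen_toy_apply():
    """Hypothetical factory mirroring the proposed gen_traceable_apply(t)."""

    def wrapped_apply(inputs):
        return {"y": inputs["x"] ** 2}

    return wrapped_apply


apply_fn = gen_toy_apply()
inputs = {"x": jnp.array([1.0, 2.0, 3.0])}

# The returned callable is passed straight to JAX transformations, no lambda needed.
jitted_out = jax.jit(apply_fn)(inputs)
batched_out = jax.vmap(apply_fn)({"x": jnp.ones((4, 3))})
jac = jax.jacrev(apply_fn)(inputs)  # indexed as jac[output_key][input_key]

print(jac["y"]["x"])  # 3x3 Jacobian of y = x**2, i.e. diag(2 * x)
```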
