
Add Option to Print/Save Intermediate Node Values in Nx.Defn.Evaluator #1606

Open
polvalente opened this issue May 6, 2025 · 0 comments

@polvalente
Contributor

We would like to add a set of options to Nx.Defn.Evaluator that let users print to the screen, or save to disk, the value of each unique Nx.Defn.Expr node during evaluation. Each node should be printed/saved only once, even if it is cached and reused multiple times in the computation graph.

Motivation

Debugging and understanding complex computation graphs in Nx can be challenging, especially when expressions are reused and cached. Being able to inspect the operation, arguments, and resulting value of each unique node, either by printing them or by saving them to disk, would greatly aid debugging, profiling, and educational use cases.

Desired Behavior

We will add the :debug_options key as an Nx.Defn compiler option.
It has the following inner options:

  • :inspect_limit (integer or :infinity): Limit on the number of elements shown in each inspected value (passed to inspect/2 as :limit).
  • :save_path (string): Directory in which to save the inspected results. If not provided, the results are printed to stdout.
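A minimal sketch of how the two inner options might be normalized before evaluation starts, assuming Elixir 1.13+ for Keyword.validate!/2. The module name DebugOptionsSketch and the defaults (inspect_limit: 50, which mirrors the default limit of inspect/2, and save_path: nil) are hypothetical choices for illustration, not part of Nx.

```elixir
defmodule DebugOptionsSketch do
  # Hypothetical defaults; the proposal does not specify any.
  @defaults [inspect_limit: 50, save_path: nil]

  # Debugging disabled: nothing to normalize.
  def normalize(nil), do: nil

  def normalize(opts) when is_list(opts) do
    # Keyword.validate!/2 raises ArgumentError on unknown keys
    # and fills in the defaults for missing ones.
    opts = Keyword.validate!(opts, @defaults)

    limit = Keyword.fetch!(opts, :inspect_limit)

    unless (is_integer(limit) and limit > 0) or limit == :infinity do
      raise ArgumentError,
            ":inspect_limit must be a positive integer or :infinity, got: #{inspect(limit)}"
    end

    save_path = Keyword.fetch!(opts, :save_path)

    unless is_nil(save_path) or is_binary(save_path) do
      raise ArgumentError, ":save_path must be a string, got: #{inspect(save_path)}"
    end

    # The presence of :save_path selects the sink: a directory or stdout.
    sink = if save_path, do: {:dir, save_path}, else: :stdout
    %{inspect_limit: limit, sink: sink}
  end
end
```

Returning nil for the disabled case lets the evaluator skip all debug work with a single check.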

Each node (identified by its unique id in the computation graph) should only be printed/saved once, even if it is cached and reused.

  • Format:
    • Text representation: Output should be similar to the following:
      [Nx.BinaryBackend]
      Node ID: <ref>
      Operation: :add
      Args: [
        <arg0 id>: #Nx.Tensor<...>,
        <arg1 id>: #Nx.Tensor<...>
      ]
      Result: #Nx.Tensor<
        f32[2][2]
        [
          [1.0, 2.0],
          [3.0, 4.0]
        ]
      >

    • For saving: Each node's info should be saved as a separate file (e.g., node_123.txt) in the specified directory.
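The text format and the one-file-per-node rule could be produced roughly as follows. This is a hypothetical sketch: NodeReportSketch is not an Nx module, the backend header line is omitted, and node ids are assumed to already be file-safe. Real expression ids are references, so an implementation would likely need to map them to a counter or hash before building node_<id>.txt.

```elixir
defmodule NodeReportSketch do
  # Builds the text report for one node. `args` is a list of
  # {arg_id, value} pairs; `limit` is forwarded to inspect/2.
  def format(id, op, args, result, limit) do
    arg_lines =
      Enum.map_join(args, ",\n", fn {arg_id, value} ->
        "  #{arg_id}: #{inspect(value, limit: limit)}"
      end)

    """
    Node ID: #{id}
    Operation: #{inspect(op)}
    Args: [
    #{arg_lines}
    ]
    Result: #{inspect(result, limit: limit)}
    """
  end

  # No :save_path given: print to stdout.
  def emit(text, _id, :stdout), do: IO.write(text)

  # :save_path given: one file per node, node_<id>.txt, in that directory.
  def emit(text, id, {:dir, dir}) do
    File.mkdir_p!(dir)
    path = Path.join(dir, "node_#{id}.txt")
    File.write!(path, text)
    path
  end
end
```

Because the limit goes straight to inspect/2, `:infinity` prints every element and a small integer truncates long values with `...`.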

Example Usage

Nx.Defn.Evaluator.__compile__(..., debug_options: [save_path: "/tmp/nx_nodes", inspect_limit: integer | :infinity])

Implementation Notes

  • The logic should hook into the evaluation step, after a node's value is computed and before it is cached.
  • A set or map should be used to track which node ids have already been printed/saved.
  • The feature should be opt-in and have minimal performance impact when disabled.
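The notes above can be illustrated with a toy evaluator. This is a sketch under stated assumptions, not Nx.Defn.Evaluator itself: nodes are plain maps (%{id:, op:, args:}) standing in for Nx.Defn.Expr, and DedupEvalSketch and its `report` callback are hypothetical. The hook runs after a node's value is computed and before it is cached; a `nil` callback (debugging disabled) short-circuits to a single function-clause match. In this toy, the cache alone would prevent repeats, but the separate MapSet follows the proposal, since a real evaluator's cache may drop entries or be rebuilt per scope (for example, inside loops).

```elixir
defmodule DedupEvalSketch do
  # `report` is nil (disabled) or a 1-arity function called once per unique node.
  def eval(node, report \\ nil) do
    {value, _cache, _seen} = do_eval(node, %{}, MapSet.new(), report)
    value
  end

  # Leaves are plain numbers.
  defp do_eval(n, cache, seen, _report) when is_number(n), do: {n, cache, seen}

  defp do_eval(%{id: id} = node, cache, seen, report) do
    case cache do
      %{^id => value} ->
        # Cache hit: the node was already computed, so it is not reported again.
        {value, cache, seen}

      %{} ->
        {arg_values, cache, seen} =
          Enum.reduce(node.args, {[], cache, seen}, fn arg, {acc, cache, seen} ->
            {v, cache, seen} = do_eval(arg, cache, seen, report)
            {[v | acc], cache, seen}
          end)

        value = apply_op(node.op, Enum.reverse(arg_values))
        # Hook point: after computing the value, before caching it.
        seen = maybe_report(report, seen, id, node.op, value)
        {value, Map.put(cache, id, value), seen}
    end
  end

  # Disabled: no set lookups, no formatting.
  defp maybe_report(nil, seen, _id, _op, _value), do: seen

  defp maybe_report(report, seen, id, op, value) do
    if MapSet.member?(seen, id) do
      seen
    else
      report.(%{id: id, op: op, value: value})
      MapSet.put(seen, id)
    end
  end

  defp apply_op(:add, [a, b]), do: a + b
  defp apply_op(:multiply, [a, b]), do: a * b
end
```

For example, with `x = %{id: 1, op: :add, args: [1, 2]}` used twice in `%{id: 2, op: :multiply, args: [x, x]}`, the callback fires once for node 1 and once for node 2, and the result is 9.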

Additional Context

This feature would be especially useful for:

  • Debugging complex models
  • Understanding how expressions are evaluated and cached
  • Teaching and demonstrations
polvalente self-assigned this May 6, 2025