
Add Option to Print/Save Intermediate Node Values in Nx.Defn.Evaluator #1606

Closed
@polvalente

Description


We would like to add a set of options to Nx.Defn.Evaluator that let users either print the value of each unique Nx.Defn.Expr node to the screen or save it to disk during evaluation. Each node should be printed or saved only once, even if it is cached and reused multiple times in the computation graph.

Motivation

Debugging and understanding complex computation graphs in Nx can be challenging, especially when expressions are reused and cached. Having the ability to inspect the operation, arguments, and resulting value of each unique node—either by printing them or saving them to disk—would greatly aid in debugging, profiling, and educational use cases.

Desired Behavior

We will add a :debug_options key as an Nx.Defn compiler option.
It accepts the following inner options:

  • :inspect_limit (integer or :infinity): Limits the number of elements shown in the inspected result (passed as the :limit option to inspect/2).
  • :save_path (string): Directory in which to save the inspected results, one file per node. If not provided, the results are printed to stdout.
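As a rough sketch of how these options might be validated, the hypothetical DebugOptionsSketch.normalize/1 below fills in defaults and rejects unknown keys with Keyword.validate!/2 (Elixir 1.13+). The default limit of 50 (which matches inspect/2's own default) and the :print / {:save, dir} mode values are assumptions for illustration, not part of the proposal.

```elixir
defmodule DebugOptionsSketch do
  # Hypothetical helper: normalizes the proposed :debug_options keyword list.
  # The default below is an assumption (it mirrors inspect/2's default limit).
  @default_inspect_limit 50

  # nil means debugging is disabled, so callers can skip all debug work.
  def normalize(nil), do: nil

  def normalize(opts) when is_list(opts) do
    # Raises ArgumentError on unknown keys and fills in missing defaults.
    opts = Keyword.validate!(opts, inspect_limit: @default_inspect_limit, save_path: nil)

    limit = opts[:inspect_limit]

    unless (is_integer(limit) and limit > 0) or limit == :infinity do
      raise ArgumentError,
            "expected :inspect_limit to be a positive integer or :infinity, got: #{inspect(limit)}"
    end

    # A :save_path means one file per node; otherwise print to stdout.
    mode = if opts[:save_path], do: {:save, opts[:save_path]}, else: :print
    %{inspect_limit: limit, mode: mode}
  end
end
```

Normalizing once, up front, keeps the per-node hook free of option parsing.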

Each node (identified by its unique id in the computation graph) should only be printed/saved once, even if it is cached and reused.

  • Format:
    • Text representation: Output should be similar to the following:
      [Nx.BinaryBackend] 
      Node ID: <ref>
      Operation: :add
      Args: [
        <arg0 id>: #Nx.Tensor<...>,
        <arg1 id>: #Nx.Tensor<...>
      ]
      Result: #Nx.Tensor<
        f32[2][2]
        1.0 2.0
        3.0 4.0
      >
      
    • Saving: each node's output should be written to a separate file (e.g., node_123.txt) in the directory given by :save_path.
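A minimal sketch of the formatting and saving steps, assuming a hypothetical node map with :id, :op, :args, :result, and :backend fields. In the evaluator these would come from the Nx.Defn.Expr node and its computed value; plain Elixir terms stand in for tensors here so the example runs without Nx. The inspect limit is passed through as inspect/2's :limit option. The filename scheme (stripping non-alphanumeric characters from the inspected id) is also an assumption: it yields node_123.txt for an integer id and something like node_Reference_0.1.2.3.txt for a reference.

```elixir
defmodule NodeReportSketch do
  # Builds the text block from the proposal for one (hypothetical) node map.
  def format(node, inspect_limit) do
    opts = [limit: inspect_limit]

    args =
      Enum.map_join(node.args, ",\n", fn {arg_id, value} ->
        "  #{inspect(arg_id)}: #{inspect(value, opts)}"
      end)

    """
    [#{inspect(node.backend)}]
    Node ID: #{inspect(node.id)}
    Operation: #{inspect(node.op)}
    Args: [
    #{args}
    ]
    Result: #{inspect(node.result, opts)}
    """
  end

  # No :save_path: print the block to stdout.
  def emit(node, inspect_limit, nil), do: IO.write(format(node, inspect_limit))

  # With a :save_path: write one file per node and return its path.
  def emit(node, inspect_limit, save_path) do
    File.mkdir_p!(save_path)
    path = Path.join(save_path, file_name(node.id))
    File.write!(path, format(node, inspect_limit))
    path
  end

  # Assumed naming scheme: keep only filename-safe characters of the inspected id.
  defp file_name(id) do
    safe =
      id
      |> inspect()
      |> String.replace(~r/[^A-Za-z0-9.]+/, "_")
      |> String.trim("_")

    "node_#{safe}.txt"
  end
end
```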

Example Usage

Nx.Defn.Evaluator.__compile__(..., debug_options: [save_path: "/tmp/nx_nodes", inspect_limit: integer | :infinity])

Implementation Notes

  • The logic should hook into the evaluation step, after a node's value is computed and before it is cached.
  • A set or map should be used to track which node ids have already been printed/saved.
  • The feature should be opt-in and have minimal performance impact when disabled.
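The hook placement described in these notes can be illustrated on a toy graph. This is not the Nx evaluator: DebugEvalSketch, its {op, arg_ids} node tuples, and the emit callback are hypothetical stand-ins. Each node is computed, passed to the debug hook, and only then cached. The seen set is deliberately kept separate from the value cache so a node is reported once even if the cache is rebuilt, and when debug is nil the hook costs a single pattern match per node.

```elixir
defmodule DebugEvalSketch do
  # Toy graph: node id => {:constant, value} | {op, [arg_id]}.
  # state = %{cache: %{id => value}, seen: MapSet.t(), debug: nil | emit_fun}

  def eval(graph, id, state) do
    case state.cache do
      %{^id => value} ->
        # Cache hit: a reused node is not printed/saved again.
        {value, state}

      %{} ->
        {value, state} = compute(graph, id, state)
        # Debug hook runs after the value is computed...
        state = maybe_debug(id, graph[id], value, state)
        # ...and before it is cached.
        {value, %{state | cache: Map.put(state.cache, id, value)}}
    end
  end

  defp compute(graph, id, state) do
    case Map.fetch!(graph, id) do
      {:constant, value} ->
        {value, state}

      {op, arg_ids} ->
        # Evaluate arguments first, threading the cache/seen state through.
        {args, state} = Enum.map_reduce(arg_ids, state, &eval(graph, &1, &2))
        {apply_op(op, args), state}
    end
  end

  defp apply_op(:add, [a, b]), do: a + b
  defp apply_op(:multiply, [a, b]), do: a * b

  # Opt-in: when disabled, nothing is tracked or emitted.
  defp maybe_debug(_id, _node, _value, %{debug: nil} = state), do: state

  defp maybe_debug(id, node, value, %{debug: emit} = state) do
    if MapSet.member?(state.seen, id) do
      state
    else
      emit.(id, node, value)
      %{state | seen: MapSet.put(state.seen, id)}
    end
  end
end
```

For example, with s = x + y and out = s * s, the node s is referenced twice by out but is emitted once; evaluating again with an empty cache and the same seen set emits nothing new.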

Additional Context

This feature would be especially useful for:

  • Debugging complex models
  • Understanding how expressions are evaluated and cached
  • Teaching and demonstrations
