Skip to content

Commit d757334

Browse files
dfmKfacJaxDev
authored andcommitted
[JAX] Fix Jaxpr comparisons for upcoming change to DropVar behavior.
JAX will be removing `DropVar` annotations on unused variables from Jaxprs since these annotations are typically only used for pretty printing. This change updates the logic in the graph matching tests to support this upcoming change. PiperOrigin-RevId: 750527617
1 parent 486f896 commit d757334

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tests/test_graph_matcher.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ def check_equation_match(self, eqn1, vars_to_vars, vars_to_eqn):
3535
"""Checks that equation is matched in the other graph."""
3636

3737
eqn1_out_vars = [v for v in eqn1.outvars
38-
if not isinstance(v, jax.core.DropVar)]
38+
if not isinstance(v, jax.core.DropVar) and
39+
v in vars_to_vars]
3940
eqn2_out_vars = [vars_to_vars[v] for v in eqn1_out_vars]
4041
eqns = [vars_to_eqn[v] for v in eqn2_out_vars]
4142
self.assertTrue(all(e == eqns[0] for e in eqns[1:]))

0 commit comments

Comments
 (0)