@@ -246,24 +246,34 @@ def __call__(
         return gm.forward


-# Equivalent to backend="aot_eager", but also records graphs that
-# we can assert on
-class AOTEagerAndRecordGraphs:
+class AotEagerAndRecordGraphs:
     def __init__(self) -> None:
         self.graphs: List[torch.fx.GraphModule] = []
+        self.fw_graphs: List[torch.fx.GraphModule] = []
+        self.bw_graphs: List[torch.fx.GraphModule] = []

     def __call__(
         self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
     ) -> Callable[..., Any]:
-        def save_graph(gm: torch.fx.GraphModule, *args: Any, **kwargs: Any) -> Any:
-            self.graphs.append(gm)
+        self.graphs.append(gm)
+
+        def fw_compiler(
+            gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
+        ) -> Callable[..., Any]:
+            self.fw_graphs.append(gm)
+            return gm.forward
+
+        def bw_compiler(
+            gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
+        ) -> Callable[..., Any]:
+            self.bw_graphs.append(gm)
             return gm.forward

         return aot_eager(
             gm,
             example_inputs,
-            fw_compiler=save_graph,
-            bw_compiler=save_graph,
+            fw_compiler=fw_compiler,
+            bw_compiler=bw_compiler,
         )
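For context, a minimal usage sketch (not part of the diff) of how a recording backend like this is typically driven through torch.compile; the function, tensor shape, and print below are illustrative assumptions, not taken from the PR.

import torch

backend = AotEagerAndRecordGraphs()

@torch.compile(backend=backend)
def fn(x: torch.Tensor) -> torch.Tensor:
    return torch.sin(x).sum()

x = torch.randn(4, requires_grad=True)
fn(x).backward()

# backend.graphs holds the Dynamo-captured graphs; backend.fw_graphs and
# backend.bw_graphs hold the AOTAutograd forward/backward graphs once
# backward() has run, which tests can then assert on.
print(len(backend.graphs), len(backend.fw_graphs), len(backend.bw_graphs))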