@@ -56,7 +56,7 @@ def forward(self, x):
5656
5757 exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
5858 trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
59- torchtrt .save (trt_module , trt_ep_path , arg_inputs = [ input ], retrace = False )
59+ torchtrt .save (trt_module , trt_ep_path , retrace = False )
6060
6161 deser_trt_module = torchtrt .load (trt_ep_path ).module ()
6262 # Check Pyt and TRT exported program outputs
@@ -111,7 +111,7 @@ def forward(self, x):
111111
112112 exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
113113 trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
114- torchtrt .save (trt_module , trt_ep_path , arg_inputs = [ input ], retrace = False )
114+ torchtrt .save (trt_module , trt_ep_path , retrace = False )
115115
116116 deser_trt_module = torchtrt .load (trt_ep_path ).module ()
117117 # Check Pyt and TRT exported program outputs
@@ -170,7 +170,7 @@ def forward(self, x):
170170
171171 exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
172172 trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
173- torchtrt .save (trt_module , trt_ep_path , arg_inputs = [ input ], retrace = False )
173+ torchtrt .save (trt_module , trt_ep_path , retrace = False )
174174
175175 deser_trt_module = torchtrt .load (trt_ep_path ).module ()
176176 # Check Pyt and TRT exported program outputs
@@ -232,7 +232,7 @@ def forward(self, x):
232232
233233 exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
234234 trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
235- torchtrt .save (trt_module , trt_ep_path , arg_inputs = [ input ], retrace = False )
235+ torchtrt .save (trt_module , trt_ep_path , retrace = False )
236236
237237 deser_trt_module = torchtrt .load (trt_ep_path ).module ()
238238 outputs_pyt = model (input )
@@ -279,7 +279,7 @@ def test_resnet18(ir):
279279
280280 exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
281281 trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
282- torchtrt .save (trt_module , trt_ep_path , arg_inputs = [ input ], retrace = False )
282+ torchtrt .save (trt_module , trt_ep_path , retrace = False )
283283
284284 deser_trt_module = torchtrt .load (trt_ep_path ).module ()
285285 outputs_pyt = model (input )
@@ -331,7 +331,7 @@ def test_resnet18_cpu_offload(ir):
331331 msg = "Model should be offloaded to CPU" ,
332332 )
333333 model .cuda ()
334- torchtrt .save (trt_module , trt_ep_path , arg_inputs = [ input ], retrace = False )
334+ torchtrt .save (trt_module , trt_ep_path , retrace = False )
335335
336336 deser_trt_module = torchtrt .load (trt_ep_path ).module ()
337337 outputs_pyt = model (input )
@@ -380,7 +380,7 @@ def test_resnet18_dynamic(ir):
380380
381381 exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
382382 trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
383- torchtrt .save (trt_module , trt_ep_path , arg_inputs = [ input ], retrace = False )
383+ torchtrt .save (trt_module , trt_ep_path , retrace = False )
384384 # TODO: Enable this serialization issues are fixed
385385 # deser_trt_module = torchtrt.load(trt_ep_path).module()
386386 outputs_pyt = model (input )
@@ -413,7 +413,7 @@ def test_resnet18_torch_exec_ops_serde(ir):
413413
414414 exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
415415 trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
416- torchtrt .save (trt_module , trt_ep_path , arg_inputs = [ input ], retrace = False )
416+ torchtrt .save (trt_module , trt_ep_path , retrace = False )
417417 deser_trt_module = torchtrt .load (trt_ep_path ).module ()
418418 outputs_pyt = deser_trt_module (input )
419419 outputs_trt = trt_module (input )
@@ -463,7 +463,7 @@ def forward(self, x):
463463 exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
464464 trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
465465
466- torchtrt .save (trt_module , trt_ep_path , arg_inputs = [ input ], retrace = False )
466+ torchtrt .save (trt_module , trt_ep_path , retrace = False )
467467
468468 deser_trt_module = torchtrt .load (trt_ep_path ).module ()
469469 outputs_pyt = model (input )
@@ -525,7 +525,7 @@ def forward(self, x):
525525 exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
526526 trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
527527 model .cuda ()
528- torchtrt .save (trt_module , trt_ep_path , arg_inputs = [ input ], retrace = False )
528+ torchtrt .save (trt_module , trt_ep_path , retrace = False )
529529
530530 deser_trt_module = torchtrt .load (trt_ep_path ).module ()
531531 outputs_pyt = model (input )
@@ -584,7 +584,7 @@ def forward(self, x):
584584 exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
585585 trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
586586
587- torchtrt .save (trt_module , trt_ep_path , arg_inputs = [ input ], retrace = False )
587+ torchtrt .save (trt_module , trt_ep_path , retrace = False )
588588
589589 deser_trt_module = torchtrt .load (trt_ep_path ).module ()
590590 outputs_pyt = model (input )
0 commit comments