Skip to content

Commit 7be898f

Browse files
committed
chore: fix tests
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
1 parent cc83065 commit 7be898f

File tree

2 files changed

+16
-16
lines changed

2 files changed

+16
-16
lines changed

tests/py/dynamo/models/test_export_kwargs_serde.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def forward(self, x, b=5, c=None, d=None):
7676

7777
# Save the module
7878
trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
79-
torchtrt.save(trt_gm, trt_ep_path)
79+
torchtrt.save(trt_gm, trt_ep_path, retrace=False)
8080
# Clean up model env
8181
torch._dynamo.reset()
8282

@@ -138,7 +138,7 @@ def forward(self, x, b=5, c=None, d=None):
138138

139139
# Save the module
140140
trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
141-
torchtrt.save(trt_gm, trt_ep_path)
141+
torchtrt.save(trt_gm, trt_ep_path, retrace=False)
142142
# Clean up model env
143143
torch._dynamo.reset()
144144

@@ -209,7 +209,7 @@ def forward(self, x, b=5, c=None, d=None):
209209

210210
# Save the module
211211
trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
212-
torchtrt.save(trt_gm, trt_ep_path)
212+
torchtrt.save(trt_gm, trt_ep_path, retrace=False)
213213
# Clean up model env
214214
torch._dynamo.reset()
215215

@@ -299,7 +299,7 @@ def forward(self, x, b=None, c=None, d=None, e=[]):
299299
)
300300
# Save the module
301301
trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
302-
torchtrt.save(trt_gm, trt_ep_path)
302+
torchtrt.save(trt_gm, trt_ep_path, retrace=False)
303303
# Clean up model env
304304
torch._dynamo.reset()
305305

@@ -389,7 +389,7 @@ def forward(self, x, b=None, c=None, d=None, e=[]):
389389
)
390390
# Save the module
391391
trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
392-
torchtrt.save(trt_gm, trt_ep_path)
392+
torchtrt.save(trt_gm, trt_ep_path, retrace=False)
393393
# Clean up model env
394394
torch._dynamo.reset()
395395

tests/py/dynamo/models/test_export_serde.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def forward(self, x):
5656

5757
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
5858
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
59-
torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
59+
torchtrt.save(trt_module, trt_ep_path, retrace=False)
6060

6161
deser_trt_module = torchtrt.load(trt_ep_path).module()
6262
# Check Pyt and TRT exported program outputs
@@ -111,7 +111,7 @@ def forward(self, x):
111111

112112
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
113113
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
114-
torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
114+
torchtrt.save(trt_module, trt_ep_path, retrace=False)
115115

116116
deser_trt_module = torchtrt.load(trt_ep_path).module()
117117
# Check Pyt and TRT exported program outputs
@@ -170,7 +170,7 @@ def forward(self, x):
170170

171171
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
172172
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
173-
torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
173+
torchtrt.save(trt_module, trt_ep_path, retrace=False)
174174

175175
deser_trt_module = torchtrt.load(trt_ep_path).module()
176176
# Check Pyt and TRT exported program outputs
@@ -232,7 +232,7 @@ def forward(self, x):
232232

233233
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
234234
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
235-
torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
235+
torchtrt.save(trt_module, trt_ep_path, retrace=False)
236236

237237
deser_trt_module = torchtrt.load(trt_ep_path).module()
238238
outputs_pyt = model(input)
@@ -279,7 +279,7 @@ def test_resnet18(ir):
279279

280280
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
281281
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
282-
torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
282+
torchtrt.save(trt_module, trt_ep_path, retrace=False)
283283

284284
deser_trt_module = torchtrt.load(trt_ep_path).module()
285285
outputs_pyt = model(input)
@@ -331,7 +331,7 @@ def test_resnet18_cpu_offload(ir):
331331
msg="Model should be offloaded to CPU",
332332
)
333333
model.cuda()
334-
torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
334+
torchtrt.save(trt_module, trt_ep_path, retrace=False)
335335

336336
deser_trt_module = torchtrt.load(trt_ep_path).module()
337337
outputs_pyt = model(input)
@@ -380,7 +380,7 @@ def test_resnet18_dynamic(ir):
380380

381381
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
382382
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
383-
torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
383+
torchtrt.save(trt_module, trt_ep_path, retrace=False)
384384
# TODO: Enable this serialization issues are fixed
385385
# deser_trt_module = torchtrt.load(trt_ep_path).module()
386386
outputs_pyt = model(input)
@@ -413,7 +413,7 @@ def test_resnet18_torch_exec_ops_serde(ir):
413413

414414
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
415415
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
416-
torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
416+
torchtrt.save(trt_module, trt_ep_path, retrace=False)
417417
deser_trt_module = torchtrt.load(trt_ep_path).module()
418418
outputs_pyt = deser_trt_module(input)
419419
outputs_trt = trt_module(input)
@@ -463,7 +463,7 @@ def forward(self, x):
463463
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
464464
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
465465

466-
torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
466+
torchtrt.save(trt_module, trt_ep_path, retrace=False)
467467

468468
deser_trt_module = torchtrt.load(trt_ep_path).module()
469469
outputs_pyt = model(input)
@@ -525,7 +525,7 @@ def forward(self, x):
525525
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
526526
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
527527
model.cuda()
528-
torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
528+
torchtrt.save(trt_module, trt_ep_path, retrace=False)
529529

530530
deser_trt_module = torchtrt.load(trt_ep_path).module()
531531
outputs_pyt = model(input)
@@ -584,7 +584,7 @@ def forward(self, x):
584584
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
585585
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
586586

587-
torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
587+
torchtrt.save(trt_module, trt_ep_path, retrace=False)
588588

589589
deser_trt_module = torchtrt.load(trt_ep_path).module()
590590
outputs_pyt = model(input)

0 commit comments

Comments
 (0)