Skip to content

Commit 7d2e33f

Browse files
committed
chore: fix some tests
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
1 parent 7be898f commit 7d2e33f

File tree

3 files changed

+16
-6
lines changed

3 files changed

+16
-6
lines changed

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,9 +274,15 @@ def refit_module_weights(
274274
else:
275275
for name, submodule in compiled_module.named_children():
276276
if not isinstance(
277-
submodule, (PythonTorchTensorRTModule, TorchTensorRTModule)
277+
submodule,
278+
(
279+
PythonTorchTensorRTModule,
280+
TorchTensorRTModule,
281+
torch.nn.modules.module.Module,
282+
),
278283
):
279284
continue
285+
280286
settings = submodule.settings
281287

282288
assert settings is not None

tests/py/dynamo/models/test_model_refit.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -540,8 +540,8 @@ def test_refit_one_engine_inline_runtime_with_weightmap():
540540
min_block_size = 1
541541
use_python_runtime = False
542542

543-
exp_program = torch.export.export(model, tuple(inputs))
544-
exp_program2 = torch.export.export(model2, tuple(inputs))
543+
exp_program = torch.export.export(model, tuple(inputs), strict=False)
544+
exp_program2 = torch.export.export(model2, tuple(inputs), strict=False)
545545

546546
trt_gm = torchtrt.dynamo.compile(
547547
exp_program,
@@ -551,8 +551,10 @@ def test_refit_one_engine_inline_runtime_with_weightmap():
551551
min_block_size=min_block_size,
552552
immutable_weights=False,
553553
)
554-
torchtrt.save(trt_gm, trt_ep_path)
554+
torchtrt.save(trt_gm, trt_ep_path, arg_inputs=inputs)
555+
555556
trt_gm = torch.export.load(trt_ep_path)
557+
556558
new_trt_gm = refit_module_weights(
557559
compiled_module=trt_gm,
558560
new_weight_module=exp_program2,
@@ -906,7 +908,7 @@ def test_refit_one_engine_inline_runtime_without_weightmap():
906908
min_block_size=min_block_size,
907909
immutable_weights=False,
908910
)
909-
torchtrt.save(trt_gm, trt_ep_path)
911+
torchtrt.save(trt_gm, trt_ep_path, arg_inputs=inputs)
910912
trt_gm = torch.export.load(trt_ep_path)
911913
new_trt_gm = refit_module_weights(
912914
compiled_module=trt_gm,

tests/py/dynamo/runtime/test_002_lazy_engine_init.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,9 @@ def test_lazy_engine_init_cpp_serialization(self):
314314
trt_mod = torchtrt.compile(model, **compile_spec)
315315

316316
with tempfile.TemporaryDirectory() as tmpdir:
317-
torch_tensorrt.save(trt_mod, os.path.join(tmpdir, "tmp_trt_mod.ep"))
317+
torch_tensorrt.save(
318+
trt_mod, os.path.join(tmpdir, "tmp_trt_mod.ep"), arg_inputs=(input,)
319+
)
318320
new_trt_mod = torch.export.load(os.path.join(tmpdir, "tmp_trt_mod.ep"))
319321

320322
loaded_trt_mod = new_trt_mod.module()

0 commit comments

Comments
 (0)