Skip to content

Commit 43039cd

Browse files
committed
tests(//py): Python tests
Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parent 482ef2c commit 43039cd

File tree

4 files changed

+118
-9
lines changed

4 files changed

+118
-9
lines changed

WORKSPACE

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,17 @@ workspace(name = "TRTorch")
33
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
44
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
55

6-
7-
8-
9-
http_archive(
6+
git_repository(
107
name = "rules_python",
11-
url = "https://github.com/bazelbuild/rules_python/releases/download/0.0.1/rules_python-0.0.1.tar.gz",
12-
sha256 = "aa96a691d3a8177f3215b14b0edc9641787abaaa30363a080165d06ab65e1161",
8+
remote = "https://github.com/bazelbuild/rules_python.git",
9+
commit = "4fcc24fd8a850bdab2ef2e078b1de337eea751a6",
10+
shallow_since = "1589292086 -0400"
1311
)
1412

1513
load("@rules_python//python:repositories.bzl", "py_repositories")
1614
py_repositories()
17-
# Only needed if using the packaging rules.
18-
load("@rules_python//python:pip.bzl", "pip_repositories", "pip_import")
15+
16+
load("@rules_python//python:pip.bzl", "pip_repositories", "pip3_import")
1917
pip_repositories()
2018

2119
http_archive(
@@ -35,7 +33,7 @@ new_local_repository(
3533
)
3634

3735
http_archive(
38-
name = "libtorch_non_cxx11_abi",
36+
name = "libtorch_pre_cxx11_abi",
3937
build_file = "@//third_party/libtorch:BUILD",
4038
strip_prefix = "libtorch",
4139
sha256 = "ea8de17c5f70015583f3a7a43c7a5cdf91a1d4bd19a6a7bc11f074ef6cd69e27",
@@ -50,6 +48,22 @@ http_archive(
5048
sha256 = "0efdd4e709ab11088fa75f0501c19b0e294404231442bab1d1fb953924feb6b5"
5149
)
5250

51+
pip3_import(
52+
name = "trtorch_py_deps",
53+
requirements = "//py:requirements.txt"
54+
)
55+
56+
load("@trtorch_py_deps//:requirements.bzl", "pip_install")
57+
pip_install()
58+
59+
pip3_import(
60+
name = "py_test_deps",
61+
requirements = "//tests/py:requirements.txt"
62+
)
63+
64+
load("@py_test_deps//:requirements.bzl", "pip_install")
65+
pip_install()
66+
5367
# Downloaded distributions to use with --distdir
5468
http_archive(
5569
name = "cudnn",

tests/py/BUILD

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package(default_visibility = ["//visibility:public"])
2+
load("@py_test_deps//:requirements.bzl", "requirement")
3+
4+
5+
py_test(
6+
name = "test_api",
7+
srcs = [
8+
"test_api.py"
9+
],
10+
deps = [
11+
requirement("torchvision")
12+
]
13+
)

tests/py/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
torchvision==0.6.0

tests/py/test_api.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import unittest
2+
import trtorch
3+
import torch
4+
import torchvision.models as models
5+
6+
7+
class ModelTestCase(unittest.TestCase):
8+
def __init__(self, methodName='runTest', model=None):
9+
super(ModelTestCase, self).__init__(methodName)
10+
self.model = model
11+
self.model.eval().to("cuda")
12+
13+
@staticmethod
14+
def parametrize(testcase_class, model=None):
15+
testloader = unittest.TestLoader()
16+
testnames = testloader.getTestCaseNames(testcase_class)
17+
suite = unittest.TestSuite()
18+
for name in testnames:
19+
suite.addTest(testcase_class(name, model=model))
20+
return suite
21+
22+
class TestCompile(ModelTestCase):
23+
def setUp(self):
24+
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
25+
self.traced_model = torch.jit.trace(self.model, [self.input])
26+
self.scripted_model = torch.jit.script(self.model)
27+
28+
def test_compile_traced(self):
29+
extra_info = {
30+
"input_shapes": [self.input.shape],
31+
}
32+
33+
trt_mod = trtorch.compile(self.traced_model, extra_info)
34+
same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max()
35+
self.assertTrue(same < 2e-3)
36+
37+
#def test_compile_script(self):
38+
# pass
39+
40+
class TestCheckMethodOpSupport(unittest.TestCase):
41+
def setUp(self):
42+
module = models.alexnet(pretrained=True).eval().to("cuda")
43+
self.module = torch.jit.trace(module, torch.ones((1, 3, 224, 224)).to("cuda"))
44+
45+
def test_check_support(self):
46+
self.assertTrue(trtorch.check_method_op_support(self.module, "forward"))
47+
48+
class TestLoggingAPIs(unittest.TestCase):
49+
def test_logging_prefix(self):
50+
new_prefix = "TEST"
51+
trtorch.logging.set_logging_prefix(new_prefix)
52+
logging_prefix = trtorch.logging.get_logging_prefix()
53+
self.assertEqual(new_prefix, logging_prefix)
54+
55+
def test_reportable_log_level(self):
56+
new_level = trtorch.logging.Level.Warning
57+
trtorch.logging.set_reportable_log_level(new_level)
58+
level = trtorch.logging.get_reportable_log_level()
59+
self.assertEqual(new_level, level)
60+
61+
def test_is_colored_output_on(self):
62+
trtorch.logging.set_is_colored_output_on(True)
63+
color = trtorch.logging.get_is_colored_output_on()
64+
self.assertTrue(color)
65+
66+
def test_suite():
67+
suite = unittest.TestSuite()
68+
suite.addTest(TestCompile.parametrize(TestCompile, model=models.resnet18(pretrained=True)))
69+
suite.addTest(TestCompile.parametrize(TestCompile, model=models.resnet50(pretrained=True)))
70+
suite.addTest(TestCompile.parametrize(TestCompile, model=models.mobilenet_v2(pretrained=True)))
71+
suite.addTest(unittest.makeSuite(TestCheckMethodOpSupport))
72+
suite.addTest(unittest.makeSuite(TestLoggingAPIs))
73+
74+
return suite
75+
76+
suite = test_suite()
77+
78+
runner = unittest.TextTestRunner()
79+
result = runner.run(suite)
80+
81+
exit(int(not result.wasSuccessful()))

0 commit comments

Comments
 (0)