1
+ import unittest
2
+ import trtorch
3
+ import torch
4
+ import torchvision .models as models
5
+
6
+
7
+ class ModelTestCase (unittest .TestCase ):
8
+ def __init__ (self , methodName = 'runTest' , model = None ):
9
+ super (ModelTestCase , self ).__init__ (methodName )
10
+ self .model = model
11
+ self .model .eval ().to ("cuda" )
12
+
13
+ @staticmethod
14
+ def parametrize (testcase_class , model = None ):
15
+ testloader = unittest .TestLoader ()
16
+ testnames = testloader .getTestCaseNames (testcase_class )
17
+ suite = unittest .TestSuite ()
18
+ for name in testnames :
19
+ suite .addTest (testcase_class (name , model = model ))
20
+ return suite
21
+
22
+ class TestCompile (ModelTestCase ):
23
+ def setUp (self ):
24
+ self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
25
+ self .traced_model = torch .jit .trace (self .model , [self .input ])
26
+ self .scripted_model = torch .jit .script (self .model )
27
+
28
+ def test_compile_traced (self ):
29
+ extra_info = {
30
+ "input_shapes" : [self .input .shape ],
31
+ }
32
+
33
+ trt_mod = trtorch .compile (self .traced_model , extra_info )
34
+ same = (trt_mod (self .input ) - self .traced_model (self .input )).abs ().max ()
35
+ self .assertTrue (same < 2e-3 )
36
+
37
+ #def test_compile_script(self):
38
+ # pass
39
+
40
+ class TestCheckMethodOpSupport (unittest .TestCase ):
41
+ def setUp (self ):
42
+ module = models .alexnet (pretrained = True ).eval ().to ("cuda" )
43
+ self .module = torch .jit .trace (module , torch .ones ((1 , 3 , 224 , 224 )).to ("cuda" ))
44
+
45
+ def test_check_support (self ):
46
+ self .assertTrue (trtorch .check_method_op_support (self .module , "forward" ))
47
+
48
+ class TestLoggingAPIs (unittest .TestCase ):
49
+ def test_logging_prefix (self ):
50
+ new_prefix = "TEST"
51
+ trtorch .logging .set_logging_prefix (new_prefix )
52
+ logging_prefix = trtorch .logging .get_logging_prefix ()
53
+ self .assertEqual (new_prefix , logging_prefix )
54
+
55
+ def test_reportable_log_level (self ):
56
+ new_level = trtorch .logging .Level .Warning
57
+ trtorch .logging .set_reportable_log_level (new_level )
58
+ level = trtorch .logging .get_reportable_log_level ()
59
+ self .assertEqual (new_level , level )
60
+
61
+ def test_is_colored_output_on (self ):
62
+ trtorch .logging .set_is_colored_output_on (True )
63
+ color = trtorch .logging .get_is_colored_output_on ()
64
+ self .assertTrue (color )
65
+
66
+ def test_suite ():
67
+ suite = unittest .TestSuite ()
68
+ suite .addTest (TestCompile .parametrize (TestCompile , model = models .resnet18 (pretrained = True )))
69
+ suite .addTest (TestCompile .parametrize (TestCompile , model = models .resnet50 (pretrained = True )))
70
+ suite .addTest (TestCompile .parametrize (TestCompile , model = models .mobilenet_v2 (pretrained = True )))
71
+ suite .addTest (unittest .makeSuite (TestCheckMethodOpSupport ))
72
+ suite .addTest (unittest .makeSuite (TestLoggingAPIs ))
73
+
74
+ return suite
75
+
76
+ suite = test_suite ()
77
+
78
+ runner = unittest .TextTestRunner ()
79
+ result = runner .run (suite )
80
+
81
+ exit (int (not result .wasSuccessful ()))
0 commit comments