
Commit ca08e7e

Making modules Jit-/Traceable & exportable to ONNX
1 parent 118c46c commit ca08e7e

File tree

4 files changed: +85 -5 lines changed

pytorch_toolbelt/modules/encoders/common.py

Lines changed: 6 additions & 5 deletions

@@ -102,32 +102,33 @@ def forward(self, x: Tensor) -> List[Tensor]:  # skipcq: PYL-W0221
         return _take(output_features, self._layers)
 
     @property
+    @torch.jit.unused
     def channels(self) -> List[int]:
         return self._output_filters
 
     @property
+    @torch.jit.unused
     def strides(self) -> List[int]:
         return self._output_strides
 
     @property
+    @torch.jit.unused
     @pytorch_toolbelt_deprecated("This property is deprecated, please use .strides instead.")
     def output_strides(self) -> List[int]:
         return self.strides
 
     @property
+    @torch.jit.unused
     @pytorch_toolbelt_deprecated("This property is deprecated, please use .channels instead.")
     def output_filters(self) -> List[int]:
         return self.channels
 
-    @property
-    @pytorch_toolbelt_deprecated("This property is deprecated, please don't use it")
-    def encoder_layers(self) -> List[nn.Module]:
-        raise NotImplementedError
-
+    @torch.jit.unused
     def set_trainable(self, trainable):
         for param in self.parameters():
             param.requires_grad = bool(trainable)
 
+    @torch.jit.unused
     def change_input_channels(self, input_channels: int, mode="auto"):
         """
         Change number of channels expected in the input tensor. By default,
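
The change common to all four files is the torch.jit.unused decorator. TorchScript's recursive compiler cannot handle Python-only constructs such as the pytorch_toolbelt_deprecated wrapper or the **kwargs signature of change_input_channels; marking these helpers as unused makes the compiler skip them while keeping them callable from eager-mode Python (invoking one from scripted code raises at runtime). A minimal sketch of that behavior, using a hypothetical ToyEncoder that is not part of pytorch_toolbelt:

    import torch
    from torch import nn, Tensor


    class ToyEncoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

        @torch.jit.unused
        def debug_stats(self, x: Tensor) -> None:
            # Arbitrary eager-only Python; the TorchScript compiler never compiles this body.
            print(f"{self.__class__.__name__}: mean={x.mean().item():.4f}")

        def forward(self, x: Tensor, verbose: bool = False) -> Tensor:
            if verbose:
                # In scripted code this call becomes a stub that raises if reached;
                # in eager mode the real method runs.
                self.debug_stats(x)
            return self.conv(x)


    scripted = torch.jit.script(ToyEncoder())  # compiles despite the eager-only helper
    out = scripted(torch.rand(1, 3, 16, 16))   # fine: the unused branch is not taken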

pytorch_toolbelt/modules/encoders/timm/efficient_net.py

Lines changed: 1 addition & 0 deletions

@@ -87,6 +87,7 @@ def forward(self, x):
         features = self.encoder(x)
         return _take(features, self._layers)
 
+    @torch.jit.unused
     def change_input_channels(self, input_channels: int, mode="auto", **kwargs):
         self.encoder.conv_stem = make_n_channel_input_conv2d_same(
             self.encoder.conv_stem, input_channels, mode, **kwargs

pytorch_toolbelt/modules/encoders/timm/resnet.py

Lines changed: 2 additions & 0 deletions

@@ -1,6 +1,7 @@
 from collections import OrderedDict
 from typing import List
 
+import torch
 from torch import nn
 
 from .common import GenericTimmEncoder
@@ -29,6 +30,7 @@ def __init__(self, pretrained=True, layers=None, activation=ACT_RELU):
         self.layer4 = encoder.body.layer4
 
     @property
+    @torch.jit.unused
     def encoder_layers(self) -> List[nn.Module]:
         return [self.stem, self.layer1, self.layer2, self.layer3, self.layer4]

tests/test_onnx.py

Lines changed: 76 additions & 0 deletions

@@ -0,0 +1,76 @@
+import pytest
+import torch
+
+import pytorch_toolbelt.modules.encoders as E
+from pytorch_toolbelt.modules.backbone.inceptionv4 import inceptionv4
+from pytorch_toolbelt.utils.torch_utils import maybe_cuda, count_parameters
+from pytorch_toolbelt.modules.encoders import timm
+
+skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="Cuda is not available")
+
+
+@pytest.mark.parametrize(
+    ["encoder", "encoder_params"],
+    [
+        [timm.B0Encoder, {"pretrained": False, "layers": [0, 1, 2, 3, 4]}],
+        [timm.MixNetXLEncoder, {"pretrained": False, "layers": [0, 1, 2, 3, 4]}],
+        [timm.SKResNet18Encoder, {"pretrained": False, "layers": [0, 1, 2, 3, 4]}],
+        [timm.SWSLResNeXt101Encoder, {"pretrained": False, "layers": [0, 1, 2, 3, 4]}],
+        [timm.TimmResnet200D, {"pretrained": False, "layers": [0, 1, 2, 3, 4]}],
+        [timm.HRNetW18Encoder, {"pretrained": False}],
+        [timm.DPN68Encoder, {"pretrained": False}],
+        [timm.NFNetF0Encoder, {"pretrained": False, "layers": [0, 1, 2, 3, 4]}],
+        [timm.NFNetF0SEncoder, {"pretrained": False, "layers": [0, 1, 2, 3, 4]}],
+        [timm.NFRegNetB0Encoder, {"pretrained": False, "layers": [0, 1, 2, 3, 4]}],
+        [timm.TimmRes2Next50Encoder, {"pretrained": False, "layers": [0, 1, 2, 3, 4]}],
+    ],
+)
+@skip_if_no_cuda
+def test_onnx_export(encoder, encoder_params):
+    import onnx
+
+    model = encoder(**encoder_params).eval()
+
+    print(model.__class__.__name__, count_parameters(model))
+    print(model.strides)
+    print(model.channels)
+    dummy_input = torch.rand((1, 3, 256, 256))
+    dummy_input = maybe_cuda(dummy_input)
+    model = maybe_cuda(model)
+
+    input_names = ["image"]
+    output_names = [f"feature_map_{i}" for i in range(len(model.channels))]
+
+    torch.onnx.export(model, dummy_input, "tmp.onnx", verbose=True, input_names=input_names, output_names=output_names)
+    model = onnx.load("tmp.onnx")
+    onnx.checker.check_model(model)
+
+
+@pytest.mark.parametrize(
+    ["encoder", "encoder_params"],
+    [
+        [timm.B0Encoder, {"pretrained": False, "layers": [0, 1, 2, 3, 4]}],
+        [timm.MixNetXLEncoder, {"pretrained": False, "layers": [0, 1, 2, 3, 4]}],
+        [timm.SKResNet18Encoder, {"pretrained": False, "layers": [0, 1, 2, 3, 4]}],
+        [timm.SWSLResNeXt101Encoder, {"pretrained": False, "layers": [0, 1, 2, 3, 4]}],
+        [timm.TimmResnet200D, {"pretrained": False, "layers": [0, 1, 2, 3, 4]}],
+        [timm.HRNetW18Encoder, {"pretrained": False}],
+        [timm.DPN68Encoder, {"pretrained": False}],
+        [timm.NFNetF0Encoder, {"pretrained": False, "layers": [0, 1, 2, 3, 4]}],
+        [timm.NFNetF0SEncoder, {"pretrained": False, "layers": [0, 1, 2, 3, 4]}],
+        [timm.NFRegNetB0Encoder, {"pretrained": False, "layers": [0, 1, 2, 3, 4]}],
+        [timm.TimmRes2Next50Encoder, {"pretrained": False, "layers": [0, 1, 2, 3, 4]}],
+    ],
+)
+@skip_if_no_cuda
+def test_jit_trace(encoder, encoder_params):
+    model = encoder(**encoder_params).eval()
+
+    print(model.__class__.__name__, count_parameters(model))
+    print(model.strides)
+    print(model.channels)
+    dummy_input = torch.rand((1, 3, 256, 256))
+    dummy_input = maybe_cuda(dummy_input)
+    model = maybe_cuda(model)
+
+    model = torch.jit.trace(model, dummy_input, check_trace=True)
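
The tmp.onnx file produced by test_onnx_export can then be run outside PyTorch. A hedged usage sketch with onnxruntime, which is not a dependency of these tests (the package choice and the CPU execution provider below are assumptions for illustration):

    import numpy as np
    import onnxruntime as ort

    # Load the graph exported by test_onnx_export above.
    session = ort.InferenceSession("tmp.onnx", providers=["CPUExecutionProvider"])

    # The export used input_names=["image"], so the feed dict uses that key.
    image = np.random.rand(1, 3, 256, 256).astype(np.float32)
    feature_maps = session.run(None, {"image": image})  # None fetches all outputs

    # One output per encoder stage, named feature_map_0 ... feature_map_N.
    for meta, fm in zip(session.get_outputs(), feature_maps):
        print(meta.name, fm.shape)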
