+ #include "torch/torch.h"
#include "core/util/prelude.h"
#include "core/conversion/converters/converters.h"
@@ -8,93 +9,59 @@ namespace converters {
namespace impl {
namespace {
- bool ConvertConvBatchNorm(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
-   auto input = args[0].ITensor();
-   auto shape = util::toVec(input->getDimensions());
-   LOG_WARNING("Assuming channel dimension is 3 because input is from a conv layer, please verify");
-   auto gamma = args[1].unwrapToTensor(at::full({shape[shape.size() - 3]}, 1));
-   auto beta = args[2].unwrapToTensor(at::full({shape[shape.size() - 3]}, 1));
-   auto mean = args[3].unwrapToTensor(at::full({shape[shape.size() - 3]}, 0));
-   auto var = args[4].unwrapToTensor(at::full({shape[shape.size() - 3]}, 0));
-   LOG_WARNING("Momentum argument is disregarded");
-   // auto momentum = args[6].unwrapToDouble(0);
-   auto eps = args[7].unwrapToDouble(0);
-
-   auto w = at::diag(gamma / at::sqrt(var + eps));
-   auto w_shape = w.sizes().vec();
-   w_shape.push_back(1);
-   w_shape.push_back(1);
-   w = w.reshape(w_shape);
-   auto b = beta - gamma * (mean / at::sqrt(var + eps));
-
-   auto weights = Weights(ctx, w);
-   auto bias = Weights(ctx, b);
-
-   auto bn_as_conv = ctx->net->addConvolutionNd(*input, weights.num_output_maps, weights.kernel_shape, weights.data, bias.data);
-   TRTORCH_CHECK(bn_as_conv, "Unable to create fused batch norm from node: " << *n);
-
-   bn_as_conv->setName(util::node_info(n).c_str());
-
-   auto bn_out = ctx->AssociateValueAndTensor(n->outputs()[0], bn_as_conv->getOutput(0));
-   LOG_DEBUG("Output tensor shape: " << bn_out->getDimensions());
-   return true;
- }
-
- bool ConvertLinearBatchNorm(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
-   auto input = args[0].ITensor();
-   auto shape = util::toVec(input->getDimensions());
-   auto gamma = args[1].unwrapToTensor(at::full({shape}, 1));
-   auto beta = args[2].unwrapToTensor(at::full({shape}, 1));
-   auto mean = args[3].unwrapToTensor(at::full({shape}, 0));
-   auto var = args[4].unwrapToTensor(at::full({shape}, 0));
-   LOG_WARNING("Momentum argument is disregarded");
-   // auto momentum = args[6].unwrapToDouble(0);
-   auto eps = args[7].unwrapToDouble(0);
-
-   auto mean_ = tensor_to_const(ctx, mean);
-   auto bot_half = at::sqrt(var + eps);
-   auto bot_half_ = tensor_to_const(ctx, bot_half);
-   auto gamma_ = tensor_to_const(ctx, gamma);
-   auto beta_ = tensor_to_const(ctx, beta);
-
-   auto top_half = ctx->net->addElementWise(*input, *mean_, nvinfer1::ElementWiseOperation::kSUB);
-   auto top_half_out = top_half->getOutput(0);
-   auto x_hat = ctx->net->addElementWise(*top_half_out, *bot_half_, nvinfer1::ElementWiseOperation::kDIV);
-   auto x_hat_out = x_hat->getOutput(0);
-   auto bn_scaled = ctx->net->addElementWise(*gamma_, *x_hat_out, nvinfer1::ElementWiseOperation::kPROD);
-   auto bn_scaled_out = bn_scaled->getOutput(0);
-   auto bn_biased = ctx->net->addElementWise(*beta_, *bn_scaled_out, nvinfer1::ElementWiseOperation::kSUM);
-   auto bn_biased_out = bn_biased->getOutput(0);
-
-   bn_biased->setName(util::node_info(n).c_str());
-   ctx->AssociateValueAndTensor(n->outputs()[0], bn_biased_out);
-
-   return true;
- }
-
- volatile auto batch_norm_registrations = RegisterNodeConversionPatterns()
-   .pattern({
-     R"SIG(aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta,
-                            Tensor? mean, Tensor? var,
-                            bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor))SIG",
-     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-       auto input = args[0].ITensor();
-       auto shape = input->getDimensions();
-       auto gamma = args[1].unwrapToTensor();
-
-       if (/*training*/ args[5].unwrapToBool()) {
-         LOG_WARNING(R"WARN(TRTorch only converts forward pass of graphs, but saw training = True, may see
-           unexpected behavior, consider placing module in eval mode before exporting the TorchScript module)WARN");
-       }
-
-       // If gamma is None this fails
-       if (util::volume(shape) == gamma.numel()) {
-         return ConvertLinearBatchNorm(ctx, n, args);
-       } else {
-         return ConvertConvBatchNorm(ctx, n, args);
-       }
-     }
-   });
+ auto batch_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
+   .pattern({
+     R"SIG(aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta,
+                            Tensor? mean, Tensor? var,
+                            bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor))SIG",
+     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+       auto input = args[0].ITensor();
+       auto orig_shape = input->getDimensions();
+       auto shape = util::toVec(orig_shape);
+       auto options = torch::TensorOptions().dtype(torch::kFloat32);
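+       // editor's note: the filled tensors below appear to be fallbacks that
+       // unwrapToTensor uses when an optional gamma/beta/mean/var input is None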
+       auto gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
+       auto beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
+       auto mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
+       auto var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
+       auto eps = args[7].unwrapToDouble(1e-5f);
+
+       LOG_DEBUG("momentum disregarded");
+       LOG_DEBUG("training disregarded");
+       LOG_DEBUG("cudnn disregarded");
+
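+       // The scale layer below works on 4D shapes, so sub-4D inputs are padded
+       // out to 4D here and restored to their original shape afterwards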
+       auto should_unpack = util::toVec(orig_shape).size() < 4;
+       if (should_unpack) {
+         // expand spatial dims from 1D to 2D
+         auto new_shape = util::toDimsPad(util::toVec(orig_shape), 4);
+         LOG_DEBUG("Input shape is less than 4D got: " << orig_shape << ", inserting shuffle layer to reshape to 4D tensor shape: " << new_shape);
+         auto in_shuffle = ctx->net->addShuffle(*input);
+         in_shuffle->setReshapeDimensions(new_shape);
+         in_shuffle->setName(std::string("[Reshape input to " + util::toStr(new_shape) + ']').c_str());
+         input = in_shuffle->getOutput(0);
+       }
+
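+       // Fold the whole normalization into one per-channel affine transform:
+       // gamma * (x - mean) / sqrt(var + eps) + beta  ==  scale * x + bias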
+       auto scale = gamma / torch::sqrt(var + eps);
+       auto bias = beta - mean * scale;
+
+       auto scale_weights = Weights(ctx, scale);
+       auto bias_weights = Weights(ctx, bias);
+
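+       // kCHANNEL with channel axis 1 applies the scale/shift per channel; the
+       // empty Weights argument leaves the layer's power term at its default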
+       auto bn = ctx->net->addScaleNd(*input, nvinfer1::ScaleMode::kCHANNEL, bias_weights.data, scale_weights.data, {}, 1);
+       bn->setName(util::node_info(n).c_str());
+       auto out_tensor = bn->getOutput(0);
+
+       if (should_unpack) {
+         LOG_DEBUG("Inserting shuffle layer to reshape back to original shape: " << orig_shape);
+         auto out_shuffle = ctx->net->addShuffle(*out_tensor);
+         out_shuffle->setReshapeDimensions(orig_shape);
+         out_shuffle->setName(std::string("[Reshape output to " + util::toStr(orig_shape) + ']').c_str());
+         out_tensor = out_shuffle->getOutput(0);
+       }
+
+       ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
+       return true;
+     }
+   });
} // namespace