
Commit 227dea3

Merge pull request #55 from NVIDIA/batch_norm_alt
fix(aten::batch_norm): A new batch norm implementation that hopefully doesn't have the same performance cost
2 parents 2a90fff + 6461872 commit 227dea3

File tree: 5 files changed, 99 additions, 87 deletions
core/conversion/converters/impl/batch_norm.cpp

Lines changed: 54 additions & 87 deletions
@@ -1,3 +1,4 @@
+#include "torch/torch.h"
 #include "core/util/prelude.h"
 #include "core/conversion/converters/converters.h"
 
@@ -8,93 +9,59 @@ namespace converters {
 namespace impl {
 namespace {
 
-bool ConvertConvBatchNorm(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
-    auto input = args[0].ITensor();
-    auto shape = util::toVec(input->getDimensions());
-    LOG_WARNING("Assuming channel dimension is 3 because input is from a conv layer, please verify");
-    auto gamma = args[1].unwrapToTensor(at::full({shape[shape.size() - 3]}, 1));
-    auto beta = args[2].unwrapToTensor(at::full({shape[shape.size() - 3]}, 1));
-    auto mean = args[3].unwrapToTensor(at::full({shape[shape.size() - 3]}, 0));
-    auto var = args[4].unwrapToTensor(at::full({shape[shape.size() - 3]}, 0));
-    LOG_WARNING("Momentum argument is disregarded");
-    //auto momentum = args[6].unwrapToDouble(0);
-    auto eps = args[7].unwrapToDouble(0);
-
-    auto w = at::diag(gamma / at::sqrt(var + eps));
-    auto w_shape = w.sizes().vec();
-    w_shape.push_back(1);
-    w_shape.push_back(1);
-    w = w.reshape(w_shape);
-    auto b = beta - gamma * (mean / at::sqrt(var + eps));
-
-    auto weights = Weights(ctx, w);
-    auto bias = Weights(ctx, b);
-
-    auto bn_as_conv = ctx->net->addConvolutionNd(*input, weights.num_output_maps, weights.kernel_shape, weights.data, bias.data);
-    TRTORCH_CHECK(bn_as_conv, "Unable to create fused batch norm from node: " << *n);
-
-    bn_as_conv->setName(util::node_info(n).c_str());
-
-    auto bn_out = ctx->AssociateValueAndTensor(n->outputs()[0], bn_as_conv->getOutput(0));
-    LOG_DEBUG("Output tensor shape: " << bn_out->getDimensions());
-    return true;
-}
-
-bool ConvertLinearBatchNorm(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
-    auto input = args[0].ITensor();
-    auto shape = util::toVec(input->getDimensions());
-    auto gamma = args[1].unwrapToTensor(at::full({shape},1));
-    auto beta = args[2].unwrapToTensor(at::full({shape},1));
-    auto mean = args[3].unwrapToTensor(at::full({shape},0));
-    auto var = args[4].unwrapToTensor(at::full({shape},0));
-    LOG_WARNING("Momentum argument is disregarded");
-    //auto momentum = args[6].unwrapToDouble(0);
-    auto eps = args[7].unwrapToDouble(0);
-
-    auto mean_ = tensor_to_const(ctx, mean);
-    auto bot_half = at::sqrt(var + eps);
-    auto bot_half_ = tensor_to_const(ctx, bot_half);
-    auto gamma_ = tensor_to_const(ctx, gamma);
-    auto beta_ = tensor_to_const(ctx, beta);
-
-    auto top_half = ctx->net->addElementWise(*input, *mean_, nvinfer1::ElementWiseOperation::kSUB);
-    auto top_half_out = top_half->getOutput(0);
-    auto x_hat = ctx->net->addElementWise(*top_half_out, *bot_half_, nvinfer1::ElementWiseOperation::kDIV);
-    auto x_hat_out = x_hat->getOutput(0);
-    auto bn_scaled = ctx->net->addElementWise(*gamma_, *x_hat_out, nvinfer1::ElementWiseOperation::kPROD);
-    auto bn_scaled_out = bn_scaled->getOutput(0);
-    auto bn_biased = ctx->net->addElementWise(*beta_, *bn_scaled_out, nvinfer1::ElementWiseOperation::kSUM);
-    auto bn_biased_out = bn_biased->getOutput(0);
-
-    bn_biased->setName(util::node_info(n).c_str());
-    ctx->AssociateValueAndTensor(n->outputs()[0], bn_biased_out);
-
-    return true;
-}
-
-volatile auto batch_norm_registrations = RegisterNodeConversionPatterns()
-    .pattern({
-        R"SIG(aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta,
-                               Tensor? mean, Tensor? var,
-                               bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor))SIG",
-        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-            auto input = args[0].ITensor();
-            auto shape = input->getDimensions();
-            auto gamma = args[1].unwrapToTensor();
-
-            if (/*training*/ args[5].unwrapToBool()) {
-                LOG_WARNING(R"WARN(TRTorch only converts forward pass of graphs, but saw training = True, may see
-                            unexpected behavior, consider placing module in eval mode before exporting the TorchScript module)WARN");
-            }
-
-            // If gamma is None this fails
-            if (util::volume(shape) == gamma.numel()) {
-                return ConvertLinearBatchNorm(ctx, n, args);
-            } else {
-                return ConvertConvBatchNorm(ctx, n, args);
-            }
-        }
-    });
+auto batch_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
+    .pattern({
+        R"SIG(aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta,
+                               Tensor? mean, Tensor? var,
+                               bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor))SIG",
+        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+            auto input = args[0].ITensor();
+            auto orig_shape = input->getDimensions();
+            auto shape = util::toVec(orig_shape);
+            auto options = torch::TensorOptions().dtype(torch::kFloat32);
+            auto gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
+            auto beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
+            auto mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
+            auto var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
+            auto eps = args[7].unwrapToDouble(1e-5f);
+
+            LOG_DEBUG("momentum disregarded");
+            LOG_DEBUG("training disregarded");
+            LOG_DEBUG("cudnn disregarded");
+
+            auto should_unpack = util::toVec(orig_shape).size() < 4;
+            if (should_unpack) {
+                // expand spatial dims from 1D to 2D
+                auto new_shape = util::toDimsPad(util::toVec(orig_shape), 4);
+                LOG_DEBUG("Input shape is less than 4D got: " << orig_shape << ", inserting shuffle layer to reshape to 4D tensor shape: " << new_shape);
+                auto in_shuffle = ctx->net->addShuffle(*input);
+                in_shuffle->setReshapeDimensions(new_shape);
+                in_shuffle->setName(std::string("[Reshape input to " + util::toStr(new_shape) + ']').c_str());
+                input = in_shuffle->getOutput(0);
+            }
+
+            auto scale = gamma / torch::sqrt(var + eps);
+            auto bias = beta - mean * scale;
+
+            auto scale_weights = Weights(ctx, scale);
+            auto bias_weights = Weights(ctx, bias);
+
+            auto bn = ctx->net->addScaleNd(*input, nvinfer1::ScaleMode::kCHANNEL, bias_weights.data, scale_weights.data, {}, 1);
+            bn->setName(util::node_info(n).c_str());
+            auto out_tensor = bn->getOutput(0);
+
+            if (should_unpack) {
+                LOG_DEBUG("Inserting shuffle layer to reshape to back to original shape: " << orig_shape);
+                auto out_shuffle = ctx->net->addShuffle(*out_tensor);
+                out_shuffle->setReshapeDimensions(orig_shape);
+                out_shuffle->setName(std::string("[Reshape output to " + util::toStr(orig_shape) + ']').c_str());
+                out_tensor = out_shuffle->getOutput(0);
+            }
+
+            ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
+            return true;
+        }
+    });
 
 
 } // namespace
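For reference, the rewritten converter folds inference-mode batch norm into a single per-channel affine transform that is emitted as one TensorRT scale layer. The `scale`/`bias` expressions in the added code correspond to the following algebra (a sketch, using the same gamma/beta/mean/var/eps as the converter):

\[
y = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta = \text{scale} \cdot x + \text{bias},
\qquad
\text{scale} = \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}},
\qquad
\text{bias} = \beta - \mu \cdot \text{scale}
\]

Only `scale` and `bias` are baked into the engine as weights, replacing the previous chain of elementwise layers.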

core/lowering/lowering.cpp

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
     passes::FuseFlattenLinear(g);
     passes::Conv2DToConvolution(g);
     passes::UnpackAddMM(g);
+    //passes::UnpackBatchNorm(g);
     passes::UnpackLogSoftmax(g);
     //passes::RemoveDimExeception(g);
     //irfusers::UnpackBatchNorm(g);

core/lowering/passes/unpack_batch_norm.cpp

Lines changed: 3 additions & 0 deletions
@@ -41,6 +41,9 @@ void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph) {
     torch::jit::SubgraphRewriter unpack_batch_norm;
     unpack_batch_norm.RegisterRewritePattern(batch_norm_pattern, expanded_batch_norm_pattern);
     unpack_batch_norm.runOnGraph(graph);
+    LOG_DEBUG("[Lowering Batch Norm]: momentum disregarded");
+    LOG_DEBUG("[Lowering Batch Norm]: training disregarded");
+    LOG_DEBUG("[Lowering Batch Norm]: cudnn disregarded");
     LOG_GRAPH("Post unpack batchnorm: " << *graph);
 }
 } // Namespace passes

tests/core/converters/BUILD

Lines changed: 5 additions & 0 deletions
@@ -4,6 +4,10 @@ converter_test(
     name = "test_activation"
 )
 
+converter_test(
+    name = "test_batch_norm"
+)
+
 converter_test(
     name = "test_conv"
 )
@@ -44,6 +48,7 @@ test_suite(
     name = "test_converters",
     tests = [
         ":test_activation",
+        ":test_batch_norm",
         ":test_conv",
        ":test_element_wise",
        ":test_linear",
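With the target registered, the new test should be runnable on its own as well as through the `:test_converters` suite; assuming the `converter_test` macro names its target after the `name` attribute, something like `bazel test //tests/core/converters:test_batch_norm` would select it.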
Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+#include <string>
+#include "gtest/gtest.h"
+#include "torch/csrc/jit/ir/irparser.h"
+#include "tests/util/util.h"
+#include "core/compiler.h"
+
+TEST(Converters, ATenBatchNormConvertsCorrectly) {
+    const auto graph = R"IR(
+        graph(%0 : Tensor,
+              %1: Float(5),
+              %2: Float(5),
+              %3: Float(5),
+              %4: Float(5)):
+          %5 : bool = prim::Constant[value=0]()
+          %6 : float = prim::Constant[value=1.0000000000000001e-05]()
+          %7 : float = prim::Constant[value=0.10000000000000001]()
+          %8 : Tensor = aten::batch_norm(%0, %1, %2, %3, %4, %5, %6, %7, %5)
+          return (%8))IR";
+
+    auto g = std::make_shared<torch::jit::Graph>();
+    torch::jit::parseIR(graph, &*g);
+
+    auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
+    auto gamma = at::randint(1, 10, {5}, {at::kCUDA});
+    auto beta = at::randint(1, 10, {5}, {at::kCUDA});
+    auto mean = at::randint(1, 10, {5}, {at::kCUDA});
+    auto var = at::randint(1, 10, {5}, {at::kCUDA});
+
+    auto params = trtorch::core::conversion::get_named_params(g->inputs(), {gamma, beta, mean, var});
+    auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+    params = trtorch::core::conversion::get_named_params(g->inputs(), {gamma, beta, mean, var});
+    auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+
+    ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
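The test above exercises the converter end to end on the GPU. As a lighter-weight sanity check of the same scale/bias fold, here is a minimal, hypothetical libtorch-only sketch (CPU, not part of this commit) that compares the folded form against `aten::batch_norm` in inference mode:

#include <torch/torch.h>
#include <iostream>

int main() {
    // Per-channel batch norm parameters for a 5-channel input
    auto gamma = torch::rand({5});
    auto beta = torch::rand({5});
    auto mean = torch::rand({5});
    auto var = torch::rand({5}) + 0.5;  // keep the variance positive
    double eps = 1e-5;

    // Same fold as the converter: y = scale * x + bias
    auto scale = gamma / torch::sqrt(var + eps);
    auto bias = beta - mean * scale;

    // Reference result from ATen batch norm in inference mode
    auto x = torch::rand({1, 5, 4, 4});
    auto ref = torch::batch_norm(x, gamma, beta, mean, var,
                                 /*training=*/false, /*momentum=*/0.1,
                                 eps, /*cudnn_enabled=*/false);

    // Folded per-channel affine applied directly to the input
    auto folded = x * scale.view({1, 5, 1, 1}) + bias.view({1, 5, 1, 1});

    // Maximum absolute difference should be on the order of float epsilon
    std::cout << (ref - folded).abs().max().item<float>() << std::endl;
    return 0;
}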
