
Commit e8d458d

refactor(//core/conversion): relax restrictions on input tensor types
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parent cbed1b9 commit e8d458d

File tree: 6 files changed (+10, -47 lines)


core/conversion/InterfaceTypes.cpp

Lines changed: 6 additions & 7 deletions
@@ -10,10 +10,9 @@ namespace conversion {
 GraphParams get_named_params(c10::ArrayRef<torch::jit::Value*> inputs,
                              std::vector<at::Tensor> params) {
     GraphParams named_params;
-    auto type_lut = torch::jit::script::string_to_type_lut();
     auto param_it = params.begin();
     for (auto in : inputs) {
-        if (in->type() != type_lut["Tensor"] \
+        if (in->type() != c10::TensorType::get() \
             && in->isCompleteTensor() && param_it != params.end()) {
             named_params[in] = *param_it;
             ++param_it;
@@ -35,7 +34,7 @@ InputRange::InputRange(std::vector<int64_t> d) {
     min = util::toDims(d);
     max = util::toDims(d);
     input_shape = util::toDims(d);
-
+
 }


@@ -48,14 +47,14 @@ InputRange::InputRange(std::vector<int64_t> min_shape, std::vector<int64_t> opt_
     sizes.insert(min_shape.size());
     sizes.insert(opt_shape.size());
     sizes.insert(max_shape.size());
-
+
     if (sizes.size() != 1) {
         LOG_ERROR("Expected all input sizes have the same dimensions, but found dimensions: min(" \
                   << min_shape.size() << "), opt("
                   << opt_shape.size() << "), max("
                   << max_shape.size() << ")");
     }
-
+
     min = util::toDimsPad(min_shape, 4);
     opt = util::toDimsPad(opt_shape, 4);
     max = util::toDimsPad(max_shape, 4);
@@ -72,9 +71,9 @@ InputRange::InputRange(std::vector<int64_t> min_shape, std::vector<int64_t> opt_
             dyn_shape.push_back(opt_shape[i]);
         }
     }
-
+
     input_shape = util::toDimsPad(dyn_shape, 4);
-
+
 }

 } // namespace conversion
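Note on the first hunk above: the commit replaces the string-keyed type lookup table with the canonical type singleton. A minimal sketch of the comparison it switches to, assuming LibTorch headers; the helper name is illustrative, not part of this repo:

#include "torch/csrc/jit/ir/ir.h"
#include "ATen/core/jit_type.h"

// Illustrative helper: c10::TensorType::get() returns a shared singleton
// for the generic, unspecialized `Tensor` type, so comparing TypePtrs with
// operator== is a pointer-identity check instead of a string lookup.
bool has_generic_tensor_type(const torch::jit::Value* in) {
    return in->type() == c10::TensorType::get();
}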

core/conversion/conversion.cpp

Lines changed: 1 addition & 3 deletions
@@ -117,16 +117,14 @@ void AddInputs(ConversionCtx* ctx,
                at::ArrayRef<const torch::jit::Value*> inputs,
                std::vector<InputRange>& input_dims) {

-    auto type_lut = torch::jit::script::string_to_type_lut();
     std::vector<const torch::jit::Value*> input_tensors;
     for (auto in : inputs) {
         // Disregarding inputs that are not tensors
         //
         // Ex.
         // self.1:__torch__.alexnet -> ignored
         // input.1:Tensor -> used
-        auto pt_type = in->type();
-        if (pt_type == type_lut["Tensor"]) {
+        if (in->type()->isSubtypeOf(c10::TensorType::get()) && ctx->evaluated_value_map.find(in) == ctx->evaluated_value_map.end()) {
             input_tensors.push_back(in);
         }
     }
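The move from exact equality to isSubtypeOf() is what actually relaxes the input-type restriction: a shape- or dtype-specialized TensorType is not pointer-equal to the generic c10::TensorType::get() singleton, but it is a subtype of it, so specialized tensor inputs are no longer filtered out of input_tensors. A minimal sketch of the distinction (the helper name is hypothetical):

#include "ATen/core/jit_type.h"

// Illustrative helper: accepts any tensor-typed value. A complete tensor
// type with known sizes and dtype fails an == check against the generic
// singleton but passes isSubtypeOf(), which is the relaxed behavior.
bool is_any_tensor_type(const c10::TypePtr& t) {
    return t->isSubtypeOf(c10::TensorType::get());
}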

core/conversion/conversion.h

Lines changed: 0 additions & 8 deletions
@@ -6,14 +6,6 @@
 #include "torch/csrc/jit/ir/ir.h"
 #include "core/conversion/conversionctx/ConversionCtx.h"

-namespace torch {
-namespace jit {
-namespace script {
-const std::unordered_map<std::string, c10::TypePtr>& string_to_type_lut();
-}
-}
-}
-
 namespace trtorch {
 namespace core {
 namespace conversion {

core/conversion/conversion_blacklist.cpp

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ const std::unordered_set<std::string>& get_non_convertable_nodes() {
         "prim::device",
         "prim::GetAttr",
         "prim::CallMethod",
+        "prim::Drop",
         "aten:dropout",
     };
     return nonconvertable_nodes;
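For context, a set like this is typically consulted with a simple membership test before a node is handed to the converters; a minimal sketch of that pattern (the function name is illustrative, not this repo's API):

#include <string>
#include <unordered_set>

// Illustrative check: a node kind listed in the set is skipped (evaluated
// or dropped) rather than lowered to TensorRT layers.
bool is_convertable(const std::string& node_kind,
                    const std::unordered_set<std::string>& blacklist) {
    return blacklist.find(node_kind) == blacklist.end();
}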

core/conversion/string_to_type_lut.cpp

Lines changed: 0 additions & 27 deletions
This file was deleted.

core/execution/register_trt_op.cpp

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@ std::vector<at::Tensor> RunCudaEngine(nvinfer1::IExecutionContext* ctx, std::pai
         auto dims = core::util::toDimsPad(inputs[i].sizes(), 1);
         auto shape = core::util::toVec(dims);
         contig_inputs.push_back(inputs[i].view(shape).contiguous());
-        LOG_DEBUG("In shape:" << shape);
+        LOG_DEBUG("In shape: " << shape);
         ctx->setBindingDimensions(i, dims);
         gpu_handles.push_back(contig_inputs.back().data_ptr());
     }
@@ -32,7 +32,7 @@ std::vector<at::Tensor> RunCudaEngine(nvinfer1::IExecutionContext* ctx, std::pai
     std::vector<at::Tensor> outputs;
     for (uint64_t o = inputs.size(); o < (io.first + io.second); o++) {
         auto out_shape = ctx->getBindingDimensions(o);
-        //LOG_DEBUG("Output: " << engine->getBindingName(o) << " out shape: " << out_shape);
+        LOG_DEBUG("Output shape: " << out_shape);
         auto dims = core::util::toVec(out_shape);
         auto type = util::toATenDType(ctx->getEngine().getBindingDataType(o));
         outputs.push_back(at::empty(dims, {at::kCUDA}).to(type).contiguous());
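Context for the second hunk: TensorRT orders engine bindings inputs first, then outputs, which is why the loop starts output indices at inputs.size() and allocates a fresh CUDA tensor for each reported binding shape. A minimal sketch of that allocation step, following the diff's own pattern (the wrapper function is illustrative):

#include "ATen/ATen.h"
#include <vector>

// Illustrative wrapper around the allocation in RunCudaEngine: given the
// shape and dtype the engine reports for an output binding, create an
// empty, contiguous CUDA tensor for TensorRT to write into.
at::Tensor alloc_output(const std::vector<int64_t>& dims, at::ScalarType type) {
    return at::empty(dims, {at::kCUDA}).to(type).contiguous();
}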
