Commit 5068f80

Merge pull request #51 from NVIDIA/docs
Adds a documentation site for TRTorch
2 parents: 272ef40 + 9f3188f

414 files changed (+123,778 additions, −35 deletions)


.gitignore
Lines changed: 4 additions & 1 deletion

@@ -24,4 +24,7 @@ cpp/ptq/datasets/data/
 tests/accuracy/datasets/data/*
 ._.DS_Store
 *.tar.gz
-*.tgz
+*.tgz
+docsrc/_build
+docsrc/_api
+docsrc/_tmp

core/lowering/lowering.cpp
Lines changed: 1 addition & 0 deletions

@@ -25,6 +25,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
     torch::jit::FuseLinear(g);
     passes::RemoveDropout(g);
     passes::FuseFlattenLinear(g);
+    passes::Conv2DToConvolution(g);
     passes::UnpackAddMM(g);
     passes::UnpackLogSoftmax(g);
     //passes::RemoveDimExeception(g);

core/lowering/passes/BUILD
Lines changed: 1 addition & 0 deletions

@@ -6,6 +6,7 @@ cc_library(
         "passes.h",
     ],
     srcs = [
+        "conv2d_to_convolution.cpp",
        "exception_elimination.cpp",
        "fuse_flatten_linear.cpp",
        "remove_dropout.cpp",

core/lowering/passes/conv2d_to_convolution.cpp
Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
+#include <torch/csrc/jit/passes/subgraph_rewrite.h>
+
+#include "core/util/prelude.h"
+
+namespace trtorch {
+namespace core {
+namespace lowering {
+namespace passes {
+
+void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
+    std::string conv2d_pattern = R"IR(
+        graph(%x, %w, %b, %s, %p, %d, %g):
+            %4 : Tensor = aten::conv2d(%x, %w, %b, %s, %p, %d, %g)
+            return (%4))IR";
+    std::string convolution_pattern = R"IR(
+        graph(%x, %w, %b, %s, %p, %d, %g):
+            %1 : bool = prim::Constant[value=1]()
+            %2 : int[] = prim::Constant[value=[0, 0]]()
+            %3 : bool = prim::Constant[value=0]()
+            %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %3)
+            return (%4))IR";;
+
+    // replace matmul + add pattern to linear
+    torch::jit::SubgraphRewriter map_conv2d_to_convolution;
+    map_conv2d_to_convolution.RegisterRewritePattern(
+        conv2d_pattern, convolution_pattern);
+    map_conv2d_to_convolution.runOnGraph(graph);
+    LOG_GRAPH("Post map conv2d -> _convolution: " << *graph);
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace trtorch
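
For context only (not part of the commit): a minimal sketch of exercising the new pass in isolation, by parsing a small TorchScript IR snippet that contains aten::conv2d, running Conv2DToConvolution over it, and dumping the rewritten graph. The irparser include path and the exact parseIR signature vary between PyTorch versions, so treat those details as assumptions.

// Hypothetical standalone driver for the new lowering pass (not part of this commit).
// Assumes the PyTorch irparser header location below; adjust for your libtorch version.
#include <iostream>
#include <memory>
#include <string>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>

#include "core/lowering/passes/passes.h"

int main() {
  // A tiny graph that calls aten::conv2d, mirroring the pattern the pass matches.
  const std::string ir = R"IR(
    graph(%x : Tensor, %w : Tensor, %b : Tensor):
      %s : int[] = prim::Constant[value=[1, 1]]()
      %p : int[] = prim::Constant[value=[0, 0]]()
      %d : int[] = prim::Constant[value=[1, 1]]()
      %g : int = prim::Constant[value=1]()
      %out : Tensor = aten::conv2d(%x, %w, %b, %s, %p, %d, %g)
      return (%out))IR";

  auto graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(ir, graph.get());

  // After the pass, the aten::conv2d node should appear as aten::_convolution.
  trtorch::core::lowering::passes::Conv2DToConvolution(graph);
  std::cout << *graph << std::endl;
  return 0;
}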

core/lowering/passes/passes.h
Lines changed: 1 addition & 0 deletions

@@ -7,6 +7,7 @@ namespace core {
 namespace lowering {
 namespace passes {
 
+void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
 void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph);
 void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);

cpp/api/include/trtorch/logging.h
Lines changed: 16 additions & 6 deletions

@@ -9,12 +9,18 @@ namespace logging {
  * Emum for setting message severity
  */
 enum Level {
-    kINTERNAL_ERROR, // Only print messages for internal errors
-    kERROR, // Print all internal errors and errors (default)
-    kWARNING, // Print warnings and errors
-    kINFO, // Print all info, warnings and errors
-    kDEBUG, // Print all debug info, info, warnings and errors
-    kGRAPH, // Print everything including the intermediate graphs of the lowering phase
+    /// Only print messages for internal errors
+    kINTERNAL_ERROR,
+    /// Print all internal errors and errors (default)
+    kERROR,
+    /// Print warnings and errors
+    kWARNING,
+    /// Print all info, warnings and errors
+    kINFO,
+    /// Print all debug info, info, warnings and errors
+    kDEBUG,
+    /// Print everything including the intermediate graphs of the lowering phase
+    kGRAPH,
 };
 
 // Are these ones necessary for the user?
@@ -37,11 +43,15 @@ TRTORCH_API void set_is_colored_output_on(bool colored_output_on);
 
 /**
  * @brief Get the current reportable log level
+ *
+ * @return TRTORCH_API get_reportable_log_level
 */
 TRTORCH_API Level get_reportable_log_level();
 
 /**
  * @brief Is colored output enabled?
+ *
+ * @return TRTORCH_API get_is_colored_output_on
 */
 TRTORCH_API bool get_is_colored_output_on();
 
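
For orientation (not part of the diff): a hedged sketch of how the two getters documented above can be combined with logging::log, which this same commit uses inside ptq.h; only declarations visible in this commit are assumed.

// Hedged usage sketch based on the declarations shown in this header
// (get_reportable_log_level, get_is_colored_output_on) and on logging::log,
// which appears later in this commit inside ptq.h.
#include <string>

#include "trtorch/logging.h"

void report_logging_state() {
  using namespace trtorch::logging;

  Level lvl = get_reportable_log_level();     // current reportable severity
  bool colored = get_is_colored_output_on();  // whether colored output is enabled

  log(Level::kINFO, std::string("Colored output enabled: ") + (colored ? "yes" : "no"));
  if (lvl == Level::kGRAPH) {
    // kGRAPH is the most verbose level: intermediate lowering graphs are printed.
    log(Level::kDEBUG, "Graph-level logging is enabled");
  }
}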

cpp/api/include/trtorch/ptq.h
Lines changed: 117 additions & 0 deletions

@@ -6,6 +6,7 @@
 #include <iostream>
 #include <sstream>
 
+#ifndef DOXYGEN_SHOULD_SKIP_THIS
 namespace nvinfer1 {
 class IInt8Calibrator;
 class IInt8EntropyCalibrator2;
@@ -17,18 +18,40 @@ template<typename Example>
 class Iterator;
 }
 }
+#endif //DOXYGEN_SHOULD_SKIP_THIS
 
 namespace trtorch {
 namespace ptq {
 
+/**
+ * @brief Generic Int8Calibrator implementation based on a specified
+ * TensorRT calibration algorithm and a LibTorch DataLoader
+ *
+ * @tparam Algorithm: class nvinfer1::IInt8Calibrator (Default: nvinfer1::IInt8EntropyCalibrator2) - Algorithm to use
+ * @tparam DataLoaderUniquePtr: std::unique_ptr<torch::data::DataLoader> - DataLoader type
+ */
 template<typename Algorithm, typename DataLoaderUniquePtr>
 class Int8Calibrator : Algorithm {
   using DataLoader = typename DataLoaderUniquePtr::element_type;
   using Batch = typename DataLoader::super::BatchType;
 public:
+  /**
+   * @brief Construct a new Int8Calibrator object
+   *
+   * Using the provided DataLoader, construct a calibrator that can be used for PTQ with TRTorch
+   *
+   * @param dataloader: std::unqiue_ptr<torch::data::DataLoader> - A unique pointer to the DataLoader, should be what is returned from the make_data_loader factory
+   * @param cache_file_path: const std::string& - A path to store / find the calibration cache
+   * @param use_cache : bool - Whether to use the cache (if it exists)
+   */
   Int8Calibrator(DataLoaderUniquePtr dataloader, const std::string& cache_file_path, bool use_cache)
     : dataloader_(dataloader.get()), it_(dataloader_->end()), cache_file_path_(cache_file_path), use_cache_(use_cache) {}
 
+  /**
+   * @brief Get the Batch Size for the next batch (always 1 due to issues with TRT and explicit batch)
+   *
+   * @return int
+   */
   int getBatchSize() const override {
     // HACK: TRTorch only uses explict batch sizing, INT8 Calibrator does not
    // work when reporting the batch size here and having explicity batching.
@@ -37,6 +60,15 @@ class Int8Calibrator : Algorithm {
     //return static_cast<int>(dataloader_->options().batch_size);
   }
 
+  /**
+   * @brief Get the next Batch
+   *
+   * @param bindings: void*[] - An array of binding pointers (fed in from TensorRT calibrator), these buffers should be filed with batch data for each input
+   * @param names: const char*[] - Names of bindings
+   * @param nbBindings: int - Number of bindings
+   * @return true - There is a new batch for the calibrator to consume
+   * @return false - There is not a new batch for the calibrator to consume
+   */
   bool getBatch(void* bindings[], const char* names[], int nbBindings) override {
     // HACK: doesnt seem like the first try in the initializer list works
     if (! it_created_) {
@@ -60,6 +92,14 @@ class Int8Calibrator : Algorithm {
     return true;
   }
 
+  /**
+   * @brief Read calibration cache
+   *
+   * How to read from the calibration cache, only enabled if use_cache is set
+   *
+   * @param length
+   * @return const void* - Pointer to cache data
+   */
   const void* readCalibrationCache(size_t& length) override {
     if (use_cache_) {
       std::stringstream ss;
@@ -81,6 +121,14 @@ class Int8Calibrator : Algorithm {
     return nullptr;
   }
 
+  /**
+   * @brief Write calibration cache
+   *
+   * Write a the calibration cache provided by TensorRT to a specified file
+   *
+   * @param cache: const void* - cache data
+   * @param length: size_t - length of cache
+   */
   void writeCalibrationCache(const void* cache, size_t length) override {
     std::ofstream cache_file(cache_file_path_, std::ios::binary);
     cache_file.write(reinterpret_cast<const char*>(cache), length);
@@ -89,37 +137,87 @@ class Int8Calibrator : Algorithm {
     logging::log(logging::Level::kINFO, ss.str());
   }
 
+  /**
+   * @brief operator to cast to nvinfer1::IInt8Calibrator*
+   *
+   * Convience function to convert to a IInt8Calibrator* to easily be assigned to the ptq_calibrator field in ExtraInfo
+   *
+   * @return nvinfer1::IInt8Calibrator*
+   */
   operator nvinfer1::IInt8Calibrator* () {
     return reinterpret_cast<nvinfer1::IInt8Calibrator*>(this);
   }
 
 private:
+  /// Pointer to the dataloader
   DataLoader* dataloader_;
+  /// Iterator used to traverse the dataloader
   torch::data::Iterator<Batch> it_;
+  /// Path to cache file
   const std::string& cache_file_path_;
+  /// Size of cache
   size_t cache_size_ = 0;
+  /// Whether to use the cache or not
   bool use_cache_;
+  /// Cache data
   std::vector<char> cache_;
+  /// If the iterator has been created, DataLoaders can only have 1 live iterator,
+  /// due to some issues this cannot be created at construction, so it is set in the first
+  /// batch, controlled by this flag
   bool it_created_ = false;
 };
 
+/**
+ * @brief Generic Int8Calibrator implementation based on a specified
+ * TensorRT calibration algorithm that only reads from a calibration file
+ *
+ * @tparam Algorithm: class nvinfer1::IInt8Calibrator (Default: nvinfer1::IInt8EntropyCalibrator2) - Algorithm to use
+ */
 template<typename Algorithm>
 class Int8CacheCalibrator : Algorithm {
 public:
+  /**
+   * @brief Construct a new Int 8 Cache Calibrator object
+   *
+   * @param cache_file_path
+   */
   Int8CacheCalibrator(const std::string& cache_file_path)
     : cache_file_path_(cache_file_path) {}
 
+  /**
+   * @brief Get the Batch Size for the next batch (always 1 due to issues with TRT and explicit batch)
+   *
+   * @return int
+   */
   int getBatchSize() const override {
     // HACK: TRTorch only uses explict batch sizing, INT8 Calibrator does not
     // work when reporting the batch size here and having explicity batching.
     // So we just report batch size 1 (warnings will still be printed out).
     return 1;
   }
 
+  /**
+   * @brief Get the next Batch
+   *
+   * Not used always returns false
+   *
+   * @param bindings: void*[] - An array of binding pointers (fed in from TensorRT calibrator), these buffers should be filed with batch data for each input
+   * @param names: const char*[] - Names of bindings
+   * @param nbBindings: int - Number of bindings
+   * @return false
+   */
   bool getBatch(void* bindings[], const char* names[], int nbBindings) override {
     return false;
   }
 
+  /**
+   * @brief Read calibration cache
+   *
+   * How to read from the calibration cache, only enabled if use_cache is set
+   *
+   * @param length
+   * @return const void* - Pointer to cache data
+   */
   const void* readCalibrationCache(size_t& length) override {
     std::stringstream ss;
     ss << "Reading Calibration Cache from " << cache_file_path_;
@@ -143,6 +241,15 @@ class Int8CacheCalibrator : Algorithm {
     return cache_size_ ? cache_.data() : nullptr;
   }
 
+
+  /**
+   * @brief Write calibration cache
+   *
+   * Write a the calibration cache provided by TensorRT to a specified file
+   *
+   * @param cache: const void* - cache data
+   * @param length: size_t - length of cache
+   */
   void writeCalibrationCache(const void* cache, size_t length) override {
     std::ofstream cache_file(cache_file_path_, std::ios::binary);
     cache_file.write(reinterpret_cast<const char*>(cache), length);
@@ -151,13 +258,23 @@ class Int8CacheCalibrator : Algorithm {
     logging::log(logging::Level::kINFO, ss.str());
   }
 
+  /**
+   * @brief operator to cast to nvinfer1::IInt8Calibrator*
+   *
+   * Convience function to convert to a IInt8Calibrator* to easily be assigned to the ptq_calibrator field in ExtraInfo
+   *
+   * @return nvinfer1::IInt8Calibrator*
+   */
   operator nvinfer1::IInt8Calibrator* () {
     return reinterpret_cast<nvinfer1::IInt8Calibrator*>(this);
   }
 
 private:
+  /// Path to cache file
   const std::string& cache_file_path_;
+  /// Size of cache
  size_t cache_size_ = 0;
+  /// Cache data
   std::vector<char> cache_;
 };
 
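
For orientation (not part of the commit): a hedged sketch of wiring the cache-only calibrator documented above into a compile call. Only the Int8CacheCalibrator constructor and its implicit nvinfer1::IInt8Calibrator* conversion are taken from this header; ExtraInfo and its ptq_calibrator field are named in the doc comments above, but the exact spelling of the ExtraInfo constructor, the op_precision field, the trtorch::CompileGraph entry point, and the cache path are assumptions made for illustration.

// Hedged sketch: reuse an existing INT8 calibration cache when compiling a module.
// ExtraInfo, op_precision, ptq_calibrator, and CompileGraph are assumed spellings of
// the TRTorch compile API referenced by the doc comments in ptq.h; verify against trtorch.h.
#include <vector>

#include <torch/script.h>
#include "NvInfer.h"          // full definition of nvinfer1::IInt8EntropyCalibrator2 is needed to instantiate the template

#include "trtorch/ptq.h"
#include "trtorch/trtorch.h"  // assumed header for trtorch::ExtraInfo / trtorch::CompileGraph

torch::jit::script::Module compile_int8_from_cache(torch::jit::script::Module& mod) {
  // Calibrator that only replays a previously written calibration cache.
  trtorch::ptq::Int8CacheCalibrator<nvinfer1::IInt8EntropyCalibrator2>
      calibrator("/tmp/int8_calibration.cache");  // hypothetical cache path

  trtorch::ExtraInfo info({{1, 3, 224, 224}});  // assumed fixed-input-shape constructor
  info.op_precision = torch::kInt8;             // assumed field for enabling INT8 kernels
  info.ptq_calibrator = calibrator;             // implicit cast to nvinfer1::IInt8Calibrator*

  return trtorch::CompileGraph(mod, info);
}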
