Commit 21e6e12
Add support for Quantization/Dequantization (#4078)
This PR introduces the following changes:

1. Torch-mlir to stablehlo conversion for per-tensor quantization/dequantization.
2. Torch-mlir to stablehlo conversion for per-channel quantization/dequantization.
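As a rough sketch of the per-tensor path (hand-written IR, not taken from this commit's test suite; shapes and constant values are illustrative, and the dtype code 12 is torch.qint8 in PyTorch's ScalarType numbering), the conversion matches a pair such as

  %scale = torch.constant.float 5.000000e-02
  %zp = torch.constant.int 10
  %dtype = torch.constant.int 12
  %q = torch.aten.quantize_per_tensor %x, %scale, %zp, %dtype
      : !torch.vtensor<[4],f32>, !torch.float, !torch.int, !torch.int
      -> !torch.vtensor<[4],!torch.qint8>
  %repr = torch.aten.int_repr %q
      : !torch.vtensor<[4],!torch.qint8> -> !torch.vtensor<[4],si8>

and rewrites it into a single op along the lines of

  %0 = stablehlo.uniform_quantize %x
      : (tensor<4xf32>) -> tensor<4x!quant.uniform<i8:f32, 5.000000e-02:10>>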
1 parent f5cfef8 commit 21e6e12

File tree

5 files changed: +359 −1 lines


lib/Conversion/TorchToStablehlo/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ add_mlir_conversion_library(TorchMLIRTorchToStablehlo
   Reduction.cpp
   Rng.cpp
   Pooling.cpp
+  Uncategorized.cpp
   Utils.cpp

   ADDITIONAL_HEADER_DIRS

lib/Conversion/TorchToStablehlo/PopulatePatterns.h

Lines changed: 4 additions & 0 deletions
@@ -67,6 +67,10 @@ void populateRngOpPatternsAndLegality(TypeConverter &typeConverter,
                                       ConversionTarget &target,
                                       const TorchToStablehloOptions &options);

+void populateUncategorizedPatternsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target, const TorchToStablehloOptions &options);
+
 } // namespace torch_to_stablehlo
 } // namespace torch
 } // namespace mlir

lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp

Lines changed: 5 additions & 1 deletion
@@ -13,6 +13,7 @@
 #include "PopulatePatterns.h"

 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "stablehlo/dialect/ChloOps.h"
 #include "stablehlo/dialect/StablehloOps.h"
@@ -40,14 +41,15 @@ class ConvertTorchToStablehlo
     registry.insert<tensor::TensorDialect>();
     registry.insert<shape::ShapeDialect>();
     registry.insert<arith::ArithDialect>();
+    registry.insert<quant::QuantDialect>();
     TorchConversion::getBackendTypeConversionDependentDialects(registry);
   }
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ConversionTarget target(*context);
     target.addLegalDialect<chlo::ChloDialect, stablehlo::StablehloDialect,
                            tensor::TensorDialect, arith::ArithDialect,
-                           shape::ShapeDialect>();
+                           shape::ShapeDialect, quant::QuantDialect>();

     TypeConverter typeConverter;
     typeConverter.addConversion([](Type type) { return type; });
@@ -72,6 +74,8 @@ class ConvertTorchToStablehlo
         typeConverter, patterns, target, options);
     torch_to_stablehlo::populateRngOpPatternsAndLegality(
         typeConverter, patterns, target, options);
+    torch_to_stablehlo::populateUncategorizedPatternsAndLegality(
+        typeConverter, patterns, target, options);

     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
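With the quant dialect registered and marked legal, quantized types can flow through the conversion unchanged. Assuming a standard torch-mlir build, the new patterns run as part of the existing pass, so an invocation such as torch-mlir-opt -convert-torch-to-stablehlo input.mlir should exercise them (the pass flag spelling follows the torch-mlir repo; verify against your checkout).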
lib/Conversion/TorchToStablehlo/Uncategorized.cpp (new file)

Lines changed: 311 additions & 0 deletions
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinAttributes.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"

#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "Utils.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include <cmath>
#include <cstdint>
#include <numeric>
#include <type_traits>

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_stablehlo;

// AtenQuantizePerTensorOp
// torch-mlir uses AtenQuantizePerTensorOp followed by AtenIntReprOp for per
// tensor quantization. The two ops are matched and converted together into a
// stablehlo.uniform_quantize op.
namespace {
class ConvertAtenQuantizePerTensorOp
    : public OpConversionPattern<AtenQuantizePerTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AtenQuantizePerTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto *zeroPoint = op.getZeroPoint().getDefiningOp();
    if (!zeroPoint || !isa<ConstantIntOp>(zeroPoint)) {
      return failure();
    }
    auto zeroPointConstantOp = mlir::cast<ConstantIntOp>(zeroPoint);
    auto zeroPointValue = zeroPointConstantOp.getValueAttr().getInt();

    auto *scale = op.getScale().getDefiningOp();
    if (!scale || !isa<ConstantFloatOp>(scale)) {
      return failure();
    }

    auto scaleConstantOp = mlir::cast<ConstantFloatOp>(scale);
    auto scaleValue =
        scaleConstantOp.getValueAttr().getValue().convertToDouble();

    // The pattern only fires when the sole user of this op is an
    // AtenIntReprOp; check the user count before dereferencing the first user.
    auto users = op.getResult().getUsers();
    if (std::distance(users.begin(), users.end()) != 1 ||
        !isa<AtenIntReprOp>(*users.begin())) {
      return failure();
    }
    Operation *opUser = *users.begin();

    auto inputElemType =
        mlir::cast<RankedTensorType>(
            getTypeConverter()->convertType(op.getOperands().front().getType()))
            .getElementType();

    mlir::Type dtype =
        cast<ValueTensorType>(op->getResult(0).getType()).getDtype();
    int32_t bitWidth = 0;
    int32_t flags = quant::QuantizationFlags::FlagValue::Signed;
    if (isa<QUInt8Type>(dtype)) {
      bitWidth = 8;
      flags = 0;
    } else if (isa<QInt8Type>(dtype)) {
      bitWidth = 8;
    } else if (isa<QInt16Type>(dtype)) {
      bitWidth = 16;
    } else if (isa<QInt32Type>(dtype)) {
      bitWidth = 32;
    } else {
      return failure();
    }
    auto storageType = IntegerType::get(getContext(), bitWidth);

    // Minimum and maximum storage values for an unsigned integer.
    int64_t minValue = 0;
    int64_t maxValue = (1LL << bitWidth) - 1;
    // Update the minimum and maximum for a signed integer (two's complement
    // representation).
    if (flags) {
      minValue = -(1LL << (bitWidth - 1));
      maxValue = (1LL << (bitWidth - 1)) - 1;
    }

    auto qty = quant::UniformQuantizedType::get(
        flags, storageType, inputElemType, scaleValue, zeroPointValue, minValue,
        maxValue);

    RankedTensorType outputType = cast<RankedTensorType>(
        getTypeConverter()->convertType(op->getResult(0).getType()));
    mlir::TensorType newType = outputType.clone(qty);

    stablehlo::UniformQuantizeOp quantize =
        rewriter.replaceOpWithNewOp<stablehlo::UniformQuantizeOp>(
            opUser, newType, adaptor.getOperands().front());

    opUser->getResults().front().replaceAllUsesWith(
        quantize->getResults().front());

    rewriter.eraseOp(op);
    return success();
  }
};
} // namespace
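To make the flag and storage-range bookkeeping above concrete (values illustrative): a qint8 result with scale 0.05 and zero point 10 gives bitWidth = 8, the Signed flag, and range [-128, 127], so quant::UniformQuantizedType::get produces an element type printing roughly as

  !quant.uniform<i8:f32, 5.000000e-02:10>

while a quint8 result clears the flag and uses the range [0, 255], printing with u8 storage instead.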

// Aten_MakePerTensorQuantizedTensorOp
// torch-mlir uses Aten_MakePerTensorQuantizedTensorOp and AtenDequantizeSelfOp
// as a pair to represent per tensor dequantization. The two ops are converted
// together into a stablehlo.uniform_dequantize op.
namespace {
class ConvertAten_MakePerTensorQuantizedTensorOp
    : public OpConversionPattern<Aten_MakePerTensorQuantizedTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Aten_MakePerTensorQuantizedTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check the user count before dereferencing the first user.
    auto users = op.getResult().getUsers();
    if (std::distance(users.begin(), users.end()) != 1 ||
        !isa<AtenDequantizeSelfOp>(*users.begin())) {
      return failure();
    }
    Operation *opUser = *users.begin();
    // [TODO] verify that zeroPoint and scale match the input operand type.
    RankedTensorType outputType = cast<RankedTensorType>(
        getTypeConverter()->convertType(opUser->getResult(0).getType()));

    rewriter.replaceOpWithNewOp<stablehlo::UniformDequantizeOp>(
        opUser, outputType, adaptor.getOperands().front());

    rewriter.eraseOp(op);
    return success();
  }
};
} // namespace
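The matched pair on the dequantize side looks roughly like the following hand-written sketch (shapes and operands illustrative):

  %q = torch.aten._make_per_tensor_quantized_tensor %repr, %scale, %zp
      : !torch.vtensor<[4],si8>, !torch.float, !torch.int
      -> !torch.vtensor<[4],!torch.qint8>
  %dq = torch.aten.dequantize.self %q
      : !torch.vtensor<[4],!torch.qint8> -> !torch.vtensor<[4],f32>

which the pattern collapses into a single stablehlo.uniform_dequantize whose operand carries the quantized element type produced earlier in the pipeline.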

// AtenQuantizePerChannelOp
// torch-mlir uses AtenQuantizePerChannelOp followed by AtenIntReprOp for per
// channel quantization. The two ops are matched and converted together into a
// stablehlo.uniform_quantize op with a per-axis quantized type.
namespace {
class ConvertAtenQuantizePerChannelOp
    : public OpConversionPattern<AtenQuantizePerChannelOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AtenQuantizePerChannelOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto *zeroPoints = op.getZeroPoints().getDefiningOp();
    if (!zeroPoints || !isa<ValueTensorLiteralOp>(zeroPoints)) {
      return failure();
    }
    auto zeroPointsOp = mlir::cast<ValueTensorLiteralOp>(zeroPoints);

    llvm::SmallVector<int64_t, 4> zeroPointsVec;
    for (auto zp : zeroPointsOp.getValue().getValues<llvm::APInt>()) {
      zeroPointsVec.emplace_back(zp.getSExtValue());
    }

    auto *scales = op.getScales().getDefiningOp();
    if (!scales || !isa<ValueTensorLiteralOp>(scales)) {
      return failure();
    }

    llvm::SmallVector<double, 4> scalesVec;
    auto scalesOp = mlir::cast<ValueTensorLiteralOp>(scales);
    for (auto scale : scalesOp.getValue().getValues<llvm::APFloat>()) {
      scalesVec.emplace_back(scale.convertToDouble());
    }

    auto *axis = op.getAxis().getDefiningOp();
    if (!axis || !isa<ConstantIntOp>(axis)) {
      return failure();
    }
    auto axisOp = mlir::cast<ConstantIntOp>(axis);
    auto axisValue = axisOp.getValueAttr().getInt();

    // The pattern only fires when the sole user of this op is an
    // AtenIntReprOp; check the user count before dereferencing the first user.
    auto users = op.getResult().getUsers();
    if (std::distance(users.begin(), users.end()) != 1 ||
        !isa<AtenIntReprOp>(*users.begin())) {
      return failure();
    }
    Operation *opUser = *users.begin();

    auto inputElemType =
        mlir::cast<RankedTensorType>(
            getTypeConverter()->convertType(op.getOperands().front().getType()))
            .getElementType();

    mlir::Type dtype =
        cast<ValueTensorType>(op->getResult(0).getType()).getDtype();
    int32_t bitWidth = 0;
    int32_t flags = quant::QuantizationFlags::FlagValue::Signed;
    if (isa<QUInt8Type>(dtype)) {
      bitWidth = 8;
      flags = 0;
    } else if (isa<QInt8Type>(dtype)) {
      bitWidth = 8;
    } else if (isa<QInt16Type>(dtype)) {
      bitWidth = 16;
    } else if (isa<QInt32Type>(dtype)) {
      bitWidth = 32;
    } else {
      return failure();
    }
    auto storageType = IntegerType::get(getContext(), bitWidth);

    // Minimum and maximum storage values for an unsigned integer.
    int64_t minValue = 0;
    int64_t maxValue = (1LL << bitWidth) - 1;
    // Update the minimum and maximum for a signed integer (two's complement
    // representation).
    if (flags) {
      minValue = -(1LL << (bitWidth - 1));
      maxValue = (1LL << (bitWidth - 1)) - 1;
    }

    auto qty = quant::UniformQuantizedPerAxisType::get(
        flags, storageType, inputElemType, scalesVec, zeroPointsVec, axisValue,
        minValue, maxValue);

    RankedTensorType outputType = cast<RankedTensorType>(
        getTypeConverter()->convertType(op->getResult(0).getType()));
    mlir::TensorType newType = outputType.clone(qty);

    stablehlo::UniformQuantizeOp quantize =
        rewriter.replaceOpWithNewOp<stablehlo::UniformQuantizeOp>(
            opUser, newType, adaptor.getOperands().front());

    opUser->getResults().front().replaceAllUsesWith(
        quantize->getResults().front());

    rewriter.eraseOp(op);
    return success();
  }
};
} // namespace
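The per-channel variant differs from the per-tensor one only in building a quant::UniformQuantizedPerAxisType from the scale and zero-point vectors plus the axis. For example (values illustrative), axis 0 with scales [0.02, 0.1] and zero points [5, -3] yields an element type printing roughly as

  !quant.uniform<i8:f32:0, {2.000000e-02:5, 1.000000e-01:-3}>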

// Aten_MakePerChannelQuantizedTensorOp
// torch-mlir uses Aten_MakePerChannelQuantizedTensorOp and
// AtenDequantizeSelfOp as a pair to represent per channel dequantization. The
// two ops are converted together into a stablehlo.uniform_dequantize op.
namespace {
class ConvertAten_MakePerChannelQuantizedTensorOp
    : public OpConversionPattern<Aten_MakePerChannelQuantizedTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Aten_MakePerChannelQuantizedTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check the user count before dereferencing the first user.
    auto users = op.getResult().getUsers();
    if (std::distance(users.begin(), users.end()) != 1 ||
        !isa<AtenDequantizeSelfOp>(*users.begin())) {
      return failure();
    }
    Operation *opUser = *users.begin();
    // [TODO] verify that zeroPoints and scales match the input operand type.
    RankedTensorType outputType = cast<RankedTensorType>(
        getTypeConverter()->convertType(opUser->getResult(0).getType()));

    rewriter.replaceOpWithNewOp<stablehlo::UniformDequantizeOp>(
        opUser, outputType, adaptor.getOperands().front());

    rewriter.eraseOp(op);
    return success();
  }
};
} // namespace

void mlir::torch::torch_to_stablehlo::populateUncategorizedPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, const TorchToStablehloOptions &options) {
  MLIRContext *context = patterns.getContext();

  target.addIllegalOp<AtenQuantizePerTensorOp>();
  target.addIllegalOp<AtenIntReprOp>();
  patterns.add<ConvertAtenQuantizePerTensorOp>(typeConverter, context);
  target.addIllegalOp<Aten_MakePerTensorQuantizedTensorOp>();
  target.addIllegalOp<AtenDequantizeSelfOp>();
  patterns.add<ConvertAten_MakePerTensorQuantizedTensorOp>(typeConverter,
                                                           context);
  target.addIllegalOp<AtenQuantizePerChannelOp>();
  patterns.add<ConvertAtenQuantizePerChannelOp>(typeConverter, context);
  patterns.add<ConvertAten_MakePerChannelQuantizedTensorOp>(typeConverter,
                                                            context);
}
