Skip to content

Commit 38e22d0

Browse files
committed
fix
1 parent 0f14933 commit 38e22d0

File tree

1 file changed

+45
-60
lines changed

1 file changed

+45
-60
lines changed

mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp

Lines changed: 45 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,8 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
#include <llvm/ADT/STLExtras.h>
1615
#define DEBUG_TYPE "decompose-lowering"
1716

18-
#include <numeric>
1917
#include <variant>
2018

2119
#include "llvm/ADT/StringMap.h"
@@ -82,14 +80,15 @@ class QubitIndex {
8280
class OpSignatureAnalyzer {
8381
public:
8482
OpSignatureAnalyzer() = delete;
85-
OpSignatureAnalyzer(CustomOp op, bool enableQregMode)
83+
OpSignatureAnalyzer(CustomOp op, bool enableQregMode, PatternRewriter &rewriter)
8684
: signature(OpSignature{
8785
.params = op.getParams(),
8886
.inQubits = op.getInQubits(),
8987
.inCtrlQubits = op.getInCtrlQubits(),
9088
.inCtrlValues = op.getInCtrlValues(),
9189
.outQubits = op.getOutQubits(),
9290
.outCtrlQubits = op.getOutCtrlQubits(),
91+
.rewriter = rewriter,
9392
})
9493
{
9594
if (!enableQregMode)
@@ -117,10 +116,6 @@ class OpSignatureAnalyzer {
117116
signature.inCtrlWireIndices.emplace_back(index);
118117
}
119118

120-
// Output qubit indices are the same as input qubit indices
121-
signature.outQubitIndices = signature.inWireIndices;
122-
signature.outCtrlQubitIndices = signature.inCtrlWireIndices;
123-
124119
assert((signature.inWireIndices.size() + signature.inCtrlWireIndices.size()) > 0 &&
125120
"inWireIndices or inCtrlWireIndices should not be empty");
126121

@@ -134,22 +129,39 @@ class OpSignatureAnalyzer {
134129
[refQreg](const QubitIndex &idx) { return idx.getReg() != refQreg; }) ||
135130
std::any_of(signature.inCtrlWireIndices.begin(), signature.inCtrlWireIndices.end(),
136131
[refQreg](const QubitIndex &idx) { return idx.getReg() != refQreg; });
132+
133+
// If needAllocQreg, the indices should be updated to from 0 to nqubits - 1
134+
// Since we will use the new qreg for the indices
135+
if (signature.needAllocQreg) {
136+
for (auto [i, index] : llvm::enumerate(signature.inWireIndices)) {
137+
auto attr = IntegerAttr::get(rewriter.getI64Type(), i);
138+
signature.inWireIndices[i] = QubitIndex(attr, index.getReg());
139+
}
140+
for (auto [i, index] : llvm::enumerate(signature.inCtrlWireIndices)) {
141+
auto attr =
142+
IntegerAttr::get(rewriter.getI64Type(), i + signature.inWireIndices.size());
143+
signature.inCtrlWireIndices[i] = QubitIndex(attr, index.getReg());
144+
}
145+
}
146+
147+
// Output qubit indices are the same as input qubit indices
148+
signature.outQubitIndices = signature.inWireIndices;
149+
signature.outCtrlQubitIndices = signature.inCtrlWireIndices;
137150
}
138151

139152
operator bool() const { return isValid; }
140153

141154
Value getUpdatedQreg(PatternRewriter &rewriter, Location loc)
142155
{
143-
Value updatedQreg = signature.inWireIndices[0].getReg();
144156
if (signature.needAllocQreg) {
145157
// allocate a new qreg with the number of qubits
146158
auto nqubits = signature.inWireIndices.size() + signature.inCtrlWireIndices.size();
147159
IntegerAttr nqubitsAttr = IntegerAttr::get(rewriter.getI64Type(), nqubits);
148160
auto allocOp = rewriter.create<quantum::AllocOp>(
149161
loc, quantum::QuregType::get(rewriter.getContext()), nullptr, nqubitsAttr);
150-
updatedQreg = allocOp.getQreg();
162+
return allocOp.getQreg();
151163
}
152-
return updatedQreg;
164+
return signature.inWireIndices[0].getReg();
153165
}
154166

155167
// Prepare the operands for calling the decomposition function
@@ -177,17 +189,16 @@ class OpSignatureAnalyzer {
177189

178190
for (auto [i, qubit] : llvm::enumerate(signature.inQubits)) {
179191
const QubitIndex &index = signature.inWireIndices[i];
192+
updatedQreg =
193+
rewriter.create<quantum::InsertOp>(loc, updatedQreg.getType(), updatedQreg,
194+
index.getValue(), index.getAttr(), qubit);
195+
}
180196

181-
if (signature.needAllocQreg) {
182-
auto attr = IntegerAttr::get(rewriter.getI64Type(), i);
183-
updatedQreg = rewriter.create<quantum::InsertOp>(
184-
loc, updatedQreg.getType(), updatedQreg, nullptr, attr, qubit);
185-
}
186-
else {
187-
updatedQreg = rewriter.create<quantum::InsertOp>(loc, updatedQreg.getType(),
188-
updatedQreg, index.getValue(),
189-
index.getAttr(), qubit);
190-
}
197+
for (auto [i, qubit] : llvm::enumerate(signature.inCtrlQubits)) {
198+
const QubitIndex &index = signature.inCtrlWireIndices[i];
199+
updatedQreg =
200+
rewriter.create<quantum::InsertOp>(loc, updatedQreg.getType(), updatedQreg,
201+
index.getValue(), index.getAttr(), qubit);
191202
}
192203

193204
operands[operandIdx++] = updatedQreg;
@@ -201,32 +212,15 @@ class OpSignatureAnalyzer {
201212
}
202213
}
203214

204-
// preprocessing indices
205-
// If needAllocQreg, the indices should be updated to from 0 to nqubits - 1
206-
// instead of the original indices, since we will use the new qreg for the indices
207-
auto inWireIndices = signature.inWireIndices;
208-
auto ctrlWireIndices = signature.inCtrlWireIndices;
209-
if (signature.needAllocQreg) {
210-
for (auto [i, index] : llvm::enumerate(inWireIndices)) {
211-
auto attr = IntegerAttr::get(rewriter.getI64Type(), i);
212-
inWireIndices[i] = QubitIndex(attr, index.getReg());
213-
}
214-
auto inWireIndicesSize = inWireIndices.size();
215-
for (auto [i, index] : llvm::enumerate(ctrlWireIndices)) {
216-
auto attr = IntegerAttr::get(rewriter.getI64Type(), i + inWireIndicesSize);
217-
ctrlWireIndices[i] = QubitIndex(attr, index.getReg());
218-
}
219-
}
220-
221-
if (!inWireIndices.empty()) {
222-
operands[operandIdx] =
223-
fromTensorOrAsIs(inWireIndices, funcInputs[operandIdx], rewriter, loc);
215+
if (!signature.inWireIndices.empty()) {
216+
operands[operandIdx] = fromTensorOrAsIs(signature.inWireIndices,
217+
funcInputs[operandIdx], rewriter, loc);
224218
operandIdx++;
225219
}
226220

227-
if (!ctrlWireIndices.empty()) {
228-
operands[operandIdx] =
229-
fromTensorOrAsIs(ctrlWireIndices, funcInputs[operandIdx], rewriter, loc);
221+
if (!signature.inCtrlWireIndices.empty()) {
222+
operands[operandIdx] = fromTensorOrAsIs(signature.inCtrlWireIndices,
223+
funcInputs[operandIdx], rewriter, loc);
230224
operandIdx++;
231225
}
232226
}
@@ -274,26 +268,13 @@ class OpSignatureAnalyzer {
274268
SmallVector<Value> newResults;
275269
rewriter.setInsertionPointAfter(callOp);
276270

277-
auto outQubitIndices = signature.outQubitIndices;
278-
auto outCtrlQubitIndices = signature.outCtrlQubitIndices;
279-
if (signature.needAllocQreg) {
280-
for (auto [i, index] : llvm::enumerate(outQubitIndices)) {
281-
auto attr = IntegerAttr::get(rewriter.getI64Type(), i);
282-
outQubitIndices[i] = QubitIndex(attr, index.getReg());
283-
}
284-
for (auto [i, index] : llvm::enumerate(outCtrlQubitIndices)) {
285-
auto attr = IntegerAttr::get(rewriter.getI64Type(), i + outQubitIndices.size());
286-
outCtrlQubitIndices[i] = QubitIndex(attr, index.getReg());
287-
}
288-
}
289-
290-
for (const QubitIndex &index : outQubitIndices) {
271+
for (const QubitIndex &index : signature.outQubitIndices) {
291272
auto extractOp = rewriter.create<quantum::ExtractOp>(
292273
callOp.getLoc(), rewriter.getType<quantum::QubitType>(), qreg, index.getValue(),
293274
index.getAttr());
294275
newResults.emplace_back(extractOp.getResult());
295276
}
296-
for (const QubitIndex &index : outCtrlQubitIndices) {
277+
for (const QubitIndex &index : signature.outCtrlQubitIndices) {
297278
auto extractOp = rewriter.create<quantum::ExtractOp>(
298279
callOp.getLoc(), rewriter.getType<quantum::QubitType>(), qreg, index.getValue(),
299280
index.getAttr());
@@ -327,6 +308,9 @@ class OpSignatureAnalyzer {
327308
// Qreg mode specific information, if true, a new qreg should be allocated before function
328309
// call and deallocated after function call
329310
bool needAllocQreg = false;
311+
312+
// Rewriter
313+
PatternRewriter &rewriter;
330314
} signature;
331315

332316
Value fromTensorOrAsIs(ValueRange values, Type type, PatternRewriter &rewriter, Location loc)
@@ -499,10 +483,11 @@ struct DecomposeLoweringRewritePattern : public OpRewritePattern<CustomOp> {
499483
"Decomposition function must have at least one result");
500484

501485
auto enableQreg = isa<quantum::QuregType>(decompFunc.getFunctionType().getInput(0));
502-
auto analyzer = OpSignatureAnalyzer(op, enableQreg);
503-
assert(analyzer && "Analyzer should be valid");
504486

505487
rewriter.setInsertionPointAfter(op);
488+
auto analyzer = OpSignatureAnalyzer(op, enableQreg, rewriter);
489+
assert(analyzer && "Analyzer should be valid");
490+
506491
auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc());
507492
auto callOp =
508493
rewriter.create<func::CallOp>(op.getLoc(), decompFunc.getFunctionType().getResults(),

0 commit comments

Comments
 (0)