1212// See the License for the specific language governing permissions and
1313// limitations under the License.
1414
15- #include < llvm/ADT/STLExtras.h>
1615#define DEBUG_TYPE " decompose-lowering"
1716
18- #include < numeric>
1917#include < variant>
2018
2119#include " llvm/ADT/StringMap.h"
@@ -82,14 +80,15 @@ class QubitIndex {
8280class OpSignatureAnalyzer {
8381 public:
8482 OpSignatureAnalyzer () = delete ;
85- OpSignatureAnalyzer (CustomOp op, bool enableQregMode)
83+ OpSignatureAnalyzer (CustomOp op, bool enableQregMode, PatternRewriter &rewriter )
8684 : signature(OpSignature{
8785 .params = op.getParams (),
8886 .inQubits = op.getInQubits (),
8987 .inCtrlQubits = op.getInCtrlQubits (),
9088 .inCtrlValues = op.getInCtrlValues (),
9189 .outQubits = op.getOutQubits (),
9290 .outCtrlQubits = op.getOutCtrlQubits (),
91+ .rewriter = rewriter,
9392 })
9493 {
9594 if (!enableQregMode)
@@ -117,10 +116,6 @@ class OpSignatureAnalyzer {
117116 signature.inCtrlWireIndices .emplace_back (index);
118117 }
119118
120- // Output qubit indices are the same as input qubit indices
121- signature.outQubitIndices = signature.inWireIndices ;
122- signature.outCtrlQubitIndices = signature.inCtrlWireIndices ;
123-
124119 assert ((signature.inWireIndices .size () + signature.inCtrlWireIndices .size ()) > 0 &&
125120 " inWireIndices or inCtrlWireIndices should not be empty" );
126121
@@ -134,22 +129,39 @@ class OpSignatureAnalyzer {
134129 [refQreg](const QubitIndex &idx) { return idx.getReg () != refQreg; }) ||
135130 std::any_of (signature.inCtrlWireIndices .begin (), signature.inCtrlWireIndices .end (),
136131 [refQreg](const QubitIndex &idx) { return idx.getReg () != refQreg; });
132+
133+ // If needAllocQreg, the indices should be updated to from 0 to nqubits - 1
134+ // Since we will use the new qreg for the indices
135+ if (signature.needAllocQreg ) {
136+ for (auto [i, index] : llvm::enumerate (signature.inWireIndices )) {
137+ auto attr = IntegerAttr::get (rewriter.getI64Type (), i);
138+ signature.inWireIndices [i] = QubitIndex (attr, index.getReg ());
139+ }
140+ for (auto [i, index] : llvm::enumerate (signature.inCtrlWireIndices )) {
141+ auto attr =
142+ IntegerAttr::get (rewriter.getI64Type (), i + signature.inWireIndices .size ());
143+ signature.inCtrlWireIndices [i] = QubitIndex (attr, index.getReg ());
144+ }
145+ }
146+
147+ // Output qubit indices are the same as input qubit indices
148+ signature.outQubitIndices = signature.inWireIndices ;
149+ signature.outCtrlQubitIndices = signature.inCtrlWireIndices ;
137150 }
138151
139152 operator bool () const { return isValid; }
140153
141154 Value getUpdatedQreg (PatternRewriter &rewriter, Location loc)
142155 {
143- Value updatedQreg = signature.inWireIndices [0 ].getReg ();
144156 if (signature.needAllocQreg ) {
145157 // allocate a new qreg with the number of qubits
146158 auto nqubits = signature.inWireIndices .size () + signature.inCtrlWireIndices .size ();
147159 IntegerAttr nqubitsAttr = IntegerAttr::get (rewriter.getI64Type (), nqubits);
148160 auto allocOp = rewriter.create <quantum::AllocOp>(
149161 loc, quantum::QuregType::get (rewriter.getContext ()), nullptr , nqubitsAttr);
150- updatedQreg = allocOp.getQreg ();
162+ return allocOp.getQreg ();
151163 }
152- return updatedQreg ;
164+ return signature. inWireIndices [ 0 ]. getReg () ;
153165 }
154166
155167 // Prepare the operands for calling the decomposition function
@@ -177,17 +189,16 @@ class OpSignatureAnalyzer {
177189
178190 for (auto [i, qubit] : llvm::enumerate (signature.inQubits )) {
179191 const QubitIndex &index = signature.inWireIndices [i];
192+ updatedQreg =
193+ rewriter.create <quantum::InsertOp>(loc, updatedQreg.getType (), updatedQreg,
194+ index.getValue (), index.getAttr (), qubit);
195+ }
180196
181- if (signature.needAllocQreg ) {
182- auto attr = IntegerAttr::get (rewriter.getI64Type (), i);
183- updatedQreg = rewriter.create <quantum::InsertOp>(
184- loc, updatedQreg.getType (), updatedQreg, nullptr , attr, qubit);
185- }
186- else {
187- updatedQreg = rewriter.create <quantum::InsertOp>(loc, updatedQreg.getType (),
188- updatedQreg, index.getValue (),
189- index.getAttr (), qubit);
190- }
197+ for (auto [i, qubit] : llvm::enumerate (signature.inCtrlQubits )) {
198+ const QubitIndex &index = signature.inCtrlWireIndices [i];
199+ updatedQreg =
200+ rewriter.create <quantum::InsertOp>(loc, updatedQreg.getType (), updatedQreg,
201+ index.getValue (), index.getAttr (), qubit);
191202 }
192203
193204 operands[operandIdx++] = updatedQreg;
@@ -201,32 +212,15 @@ class OpSignatureAnalyzer {
201212 }
202213 }
203214
204- // preprocessing indices
205- // If needAllocQreg, the indices should be updated to from 0 to nqubits - 1
206- // instead of the original indices, since we will use the new qreg for the indices
207- auto inWireIndices = signature.inWireIndices ;
208- auto ctrlWireIndices = signature.inCtrlWireIndices ;
209- if (signature.needAllocQreg ) {
210- for (auto [i, index] : llvm::enumerate (inWireIndices)) {
211- auto attr = IntegerAttr::get (rewriter.getI64Type (), i);
212- inWireIndices[i] = QubitIndex (attr, index.getReg ());
213- }
214- auto inWireIndicesSize = inWireIndices.size ();
215- for (auto [i, index] : llvm::enumerate (ctrlWireIndices)) {
216- auto attr = IntegerAttr::get (rewriter.getI64Type (), i + inWireIndicesSize);
217- ctrlWireIndices[i] = QubitIndex (attr, index.getReg ());
218- }
219- }
220-
221- if (!inWireIndices.empty ()) {
222- operands[operandIdx] =
223- fromTensorOrAsIs (inWireIndices, funcInputs[operandIdx], rewriter, loc);
215+ if (!signature.inWireIndices .empty ()) {
216+ operands[operandIdx] = fromTensorOrAsIs (signature.inWireIndices ,
217+ funcInputs[operandIdx], rewriter, loc);
224218 operandIdx++;
225219 }
226220
227- if (!ctrlWireIndices .empty ()) {
228- operands[operandIdx] =
229- fromTensorOrAsIs (ctrlWireIndices, funcInputs[operandIdx], rewriter, loc);
221+ if (!signature. inCtrlWireIndices .empty ()) {
222+ operands[operandIdx] = fromTensorOrAsIs (signature. inCtrlWireIndices ,
223+ funcInputs[operandIdx], rewriter, loc);
230224 operandIdx++;
231225 }
232226 }
@@ -274,26 +268,13 @@ class OpSignatureAnalyzer {
274268 SmallVector<Value> newResults;
275269 rewriter.setInsertionPointAfter (callOp);
276270
277- auto outQubitIndices = signature.outQubitIndices ;
278- auto outCtrlQubitIndices = signature.outCtrlQubitIndices ;
279- if (signature.needAllocQreg ) {
280- for (auto [i, index] : llvm::enumerate (outQubitIndices)) {
281- auto attr = IntegerAttr::get (rewriter.getI64Type (), i);
282- outQubitIndices[i] = QubitIndex (attr, index.getReg ());
283- }
284- for (auto [i, index] : llvm::enumerate (outCtrlQubitIndices)) {
285- auto attr = IntegerAttr::get (rewriter.getI64Type (), i + outQubitIndices.size ());
286- outCtrlQubitIndices[i] = QubitIndex (attr, index.getReg ());
287- }
288- }
289-
290- for (const QubitIndex &index : outQubitIndices) {
271+ for (const QubitIndex &index : signature.outQubitIndices ) {
291272 auto extractOp = rewriter.create <quantum::ExtractOp>(
292273 callOp.getLoc (), rewriter.getType <quantum::QubitType>(), qreg, index.getValue (),
293274 index.getAttr ());
294275 newResults.emplace_back (extractOp.getResult ());
295276 }
296- for (const QubitIndex &index : outCtrlQubitIndices) {
277+ for (const QubitIndex &index : signature. outCtrlQubitIndices ) {
297278 auto extractOp = rewriter.create <quantum::ExtractOp>(
298279 callOp.getLoc (), rewriter.getType <quantum::QubitType>(), qreg, index.getValue (),
299280 index.getAttr ());
@@ -327,6 +308,9 @@ class OpSignatureAnalyzer {
327308 // Qreg mode specific information, if true, a new qreg should be allocated before function
328309 // call and deallocated after function call
329310 bool needAllocQreg = false ;
311+
312+ // Rewriter
313+ PatternRewriter &rewriter;
330314 } signature;
331315
332316 Value fromTensorOrAsIs (ValueRange values, Type type, PatternRewriter &rewriter, Location loc)
@@ -499,10 +483,11 @@ struct DecomposeLoweringRewritePattern : public OpRewritePattern<CustomOp> {
499483 " Decomposition function must have at least one result" );
500484
501485 auto enableQreg = isa<quantum::QuregType>(decompFunc.getFunctionType ().getInput (0 ));
502- auto analyzer = OpSignatureAnalyzer (op, enableQreg);
503- assert (analyzer && " Analyzer should be valid" );
504486
505487 rewriter.setInsertionPointAfter (op);
488+ auto analyzer = OpSignatureAnalyzer (op, enableQreg, rewriter);
489+ assert (analyzer && " Analyzer should be valid" );
490+
506491 auto callOperands = analyzer.prepareCallOperands (decompFunc, rewriter, op.getLoc ());
507492 auto callOp =
508493 rewriter.create <func::CallOp>(op.getLoc (), decompFunc.getFunctionType ().getResults (),
0 commit comments