1212// See the License for the specific language governing permissions and
1313// limitations under the License.
1414
15+ #include < llvm/ADT/STLExtras.h>
1516#define DEBUG_TYPE " decompose-lowering"
1617
18+ #include < numeric>
1719#include < variant>
1820
1921#include " llvm/ADT/StringMap.h"
@@ -40,6 +42,7 @@ namespace quantum {
4042// / - A runtime Value (for dynamic indices computed at runtime)
4143// / - An IntegerAttr (for compile-time constant indices)
4244// / - Invalid/uninitialized (represented by std::monostate)
45+ // / And a qreg value to represent the qreg that the index belongs to
4346// /
4447// / The struct uses std::variant to ensure only one type is active at a time,
4548// / preventing invalid states.
@@ -54,17 +57,21 @@ namespace quantum {
5457// / Value idx = dynamicIdx.getValue(); // Get the Value
5558// / }
5659// / }
57- struct QubitIndex {
60+ class QubitIndex {
61+ private:
5862 // use monostate to represent the invalid index
5963 std::variant<std::monostate, Value, IntegerAttr> index;
64+ Value qreg;
6065
61- QubitIndex () : index(std::monostate()) {}
62- QubitIndex (Value val) : index(val) {}
63- QubitIndex (IntegerAttr attr) : index(attr) {}
66+ public:
67+ QubitIndex () : index(std::monostate()), qreg(nullptr ) {}
68+ QubitIndex (Value val, Value qreg) : index(val), qreg(qreg) {}
69+ QubitIndex (IntegerAttr attr, Value qreg) : index(attr), qreg(qreg) {}
6470
6571 bool isValue () const { return std::holds_alternative<Value>(index); }
6672 bool isAttr () const { return std::holds_alternative<IntegerAttr>(index); }
6773 operator bool () const { return isValue () || isAttr (); }
74+ Value getReg () const { return qreg; }
6875 Value getValue () const { return isValue () ? std::get<Value>(index) : nullptr ; }
6976 IntegerAttr getAttr () const { return isAttr () ? std::get<IntegerAttr>(index) : nullptr ; }
7077};
@@ -88,13 +95,6 @@ class OpSignatureAnalyzer {
8895 if (!enableQregMode)
8996 return ;
9097
91- signature.sourceQreg = getSourceQreg (signature.inQubits .front ());
92- if (!signature.sourceQreg ) {
93- op.emitError (" Cannot get source qreg" );
94- isValid = false ;
95- return ;
96- }
97-
9898 // input wire indices
9999 for (Value qubit : signature.inQubits ) {
100100 const QubitIndex index = getExtractIndex (qubit);
@@ -120,6 +120,20 @@ class OpSignatureAnalyzer {
120120 // Output qubit indices are the same as input qubit indices
121121 signature.outQubitIndices = signature.inWireIndices ;
122122 signature.outCtrlQubitIndices = signature.inCtrlWireIndices ;
123+
124+ assert ((signature.inWireIndices .size () + signature.inCtrlWireIndices .size ()) > 0 &&
125+ " inWireIndices or inCtrlWireIndices should not be empty" );
126+
127+ // Get the first qreg as reference
128+ Value refQreg = !signature.inWireIndices .empty () ? signature.inWireIndices [0 ].getReg ()
129+ : signature.inCtrlWireIndices [0 ].getReg ();
130+
131+ // Check if any qreg is different
132+ signature.needAllocQreg =
133+ std::any_of (signature.inWireIndices .begin (), signature.inWireIndices .end (),
134+ [refQreg](const QubitIndex &idx) { return idx.getReg () != refQreg; }) ||
135+ std::any_of (signature.inCtrlWireIndices .begin (), signature.inCtrlWireIndices .end (),
136+ [refQreg](const QubitIndex &idx) { return idx.getReg () != refQreg; });
123137 }
124138
125139 operator bool () const { return isValid; }
@@ -144,12 +158,30 @@ class OpSignatureAnalyzer {
144158
145159 int operandIdx = 0 ;
146160 if (isa<quantum::QuregType>(funcInputs[0 ])) {
147- Value updatedQreg = signature.sourceQreg ;
161+ // Allocate a new qreg if needed
162+ Value updatedQreg = signature.inWireIndices [0 ].getReg ();
163+ if (signature.needAllocQreg ) {
164+ // allocate a new qreg with the number of qubits
165+ auto nqubits = signature.inWireIndices .size () + signature.inCtrlWireIndices .size ();
166+ IntegerAttr nqubitsAttr = IntegerAttr::get (rewriter.getI64Type (), nqubits);
167+ auto allocOp = rewriter.create <quantum::AllocOp>(
168+ loc, quantum::QuregType::get (rewriter.getContext ()), nullptr , nqubitsAttr);
169+ updatedQreg = allocOp.getQreg ();
170+ }
171+
148172 for (auto [i, qubit] : llvm::enumerate (signature.inQubits )) {
149173 const QubitIndex &index = signature.inWireIndices [i];
150- updatedQreg =
151- rewriter.create <quantum::InsertOp>(loc, updatedQreg.getType (), updatedQreg,
152- index.getValue (), index.getAttr (), qubit);
174+
175+ if (signature.needAllocQreg ) {
176+ auto attr = IntegerAttr::get (rewriter.getI64Type (), i);
177+ updatedQreg = rewriter.create <quantum::InsertOp>(
178+ loc, updatedQreg.getType (), updatedQreg, nullptr , attr, qubit);
179+ }
180+ else {
181+ updatedQreg = rewriter.create <quantum::InsertOp>(loc, updatedQreg.getType (),
182+ updatedQreg, index.getValue (),
183+ index.getAttr (), qubit);
184+ }
153185 }
154186
155187 operands[operandIdx++] = updatedQreg;
@@ -163,15 +195,32 @@ class OpSignatureAnalyzer {
163195 }
164196 }
165197
166- if (!signature.inWireIndices .empty ()) {
167- operands[operandIdx] = fromTensorOrAsIs (signature.inWireIndices ,
168- funcInputs[operandIdx], rewriter, loc);
198+ // preprocessing indices
199+ // If needAllocQreg, the indices should be updated to from 0 to nqubits - 1
200+ // instead of the original indices, since we will use the new qreg for the indices
201+ auto wireIndices = signature.inWireIndices ;
202+ auto ctrlWireIndices = signature.inCtrlWireIndices ;
203+ if (signature.needAllocQreg ) {
204+ for (auto [i, index] : llvm::enumerate (wireIndices)) {
205+ auto attr = IntegerAttr::get (rewriter.getI64Type (), i);
206+ wireIndices[i] = QubitIndex (attr, index.getReg ());
207+ }
208+ auto inWireIndicesSize = wireIndices.size ();
209+ for (auto [i, index] : llvm::enumerate (ctrlWireIndices)) {
210+ auto attr = IntegerAttr::get (rewriter.getI64Type (), i + inWireIndicesSize);
211+ ctrlWireIndices[i] = QubitIndex (attr, index.getReg ());
212+ }
213+ }
214+
215+ if (!wireIndices.empty ()) {
216+ operands[operandIdx] =
217+ fromTensorOrAsIs (wireIndices, funcInputs[operandIdx], rewriter, loc);
169218 operandIdx++;
170219 }
171220
172- if (!signature. inCtrlWireIndices .empty ()) {
173- operands[operandIdx] = fromTensorOrAsIs (signature. inCtrlWireIndices ,
174- funcInputs[operandIdx], rewriter, loc);
221+ if (!ctrlWireIndices .empty ()) {
222+ operands[operandIdx] =
223+ fromTensorOrAsIs (ctrlWireIndices, funcInputs[operandIdx], rewriter, loc);
175224 operandIdx++;
176225 }
177226 }
@@ -218,18 +267,37 @@ class OpSignatureAnalyzer {
218267
219268 SmallVector<Value> newResults;
220269 rewriter.setInsertionPointAfter (callOp);
221- for (const QubitIndex &index : signature.outQubitIndices ) {
270+
271+ auto outQubitIndices = signature.outQubitIndices ;
272+ auto outCtrlQubitIndices = signature.outCtrlQubitIndices ;
273+ if (signature.needAllocQreg ) {
274+ for (auto [i, index] : llvm::enumerate (outQubitIndices)) {
275+ auto attr = IntegerAttr::get (rewriter.getI64Type (), i);
276+ outQubitIndices[i] = QubitIndex (attr, index.getReg ());
277+ }
278+ for (auto [i, index] : llvm::enumerate (outCtrlQubitIndices)) {
279+ auto attr = IntegerAttr::get (rewriter.getI64Type (), i + outQubitIndices.size ());
280+ outCtrlQubitIndices[i] = QubitIndex (attr, index.getReg ());
281+ }
282+ }
283+
284+ for (const QubitIndex &index : outQubitIndices) {
222285 auto extractOp = rewriter.create <quantum::ExtractOp>(
223286 callOp.getLoc (), rewriter.getType <quantum::QubitType>(), qreg, index.getValue (),
224287 index.getAttr ());
225288 newResults.emplace_back (extractOp.getResult ());
226289 }
227- for (const QubitIndex &index : signature. outCtrlQubitIndices ) {
290+ for (const QubitIndex &index : outCtrlQubitIndices) {
228291 auto extractOp = rewriter.create <quantum::ExtractOp>(
229292 callOp.getLoc (), rewriter.getType <quantum::QubitType>(), qreg, index.getValue (),
230293 index.getAttr ());
231294 newResults.emplace_back (extractOp.getResult ());
232295 }
296+
297+ if (signature.needAllocQreg ) {
298+ rewriter.create <quantum::DeallocOp>(callOp.getLoc (), qreg);
299+ }
300+
233301 return newResults;
234302 }
235303
@@ -245,11 +313,14 @@ class OpSignatureAnalyzer {
245313 ValueRange outCtrlQubits;
246314
247315 // Qreg mode specific information
248- Value sourceQreg = nullptr ;
249316 SmallVector<QubitIndex> inWireIndices;
250317 SmallVector<QubitIndex> inCtrlWireIndices;
251318 SmallVector<QubitIndex> outQubitIndices;
252319 SmallVector<QubitIndex> outCtrlQubitIndices;
320+
321+ // Qreg mode specific information, if true, a new qreg should be allocated before function
322+ // call and deallocated after function call
323+ bool needAllocQreg = false ;
253324 } signature;
254325
255326 Value fromTensorOrAsIs (ValueRange values, Type type, PatternRewriter &rewriter, Location loc)
@@ -356,10 +427,10 @@ class OpSignatureAnalyzer {
356427 while (qubit) {
357428 if (auto extractOp = qubit.getDefiningOp <quantum::ExtractOp>()) {
358429 if (Value idx = extractOp.getIdx ()) {
359- return QubitIndex (idx);
430+ return QubitIndex (idx, extractOp. getQreg () );
360431 }
361432 if (IntegerAttr idxAttr = extractOp.getIdxAttrAttr ()) {
362- return QubitIndex (idxAttr);
433+ return QubitIndex (idxAttr, extractOp. getQreg () );
363434 }
364435 }
365436
0 commit comments