Skip to content

Commit cdef0ff

Browse files
committed
Add dynamic qreg alloc for cross-qreg gate decomposition
1 parent 134e6f2 commit cdef0ff

File tree

2 files changed

+183
-26
lines changed

2 files changed

+183
-26
lines changed

mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp

Lines changed: 97 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
#include <llvm/ADT/STLExtras.h>
1516
#define DEBUG_TYPE "decompose-lowering"
1617

18+
#include <numeric>
1719
#include <variant>
1820

1921
#include "llvm/ADT/StringMap.h"
@@ -40,6 +42,7 @@ namespace quantum {
4042
/// - A runtime Value (for dynamic indices computed at runtime)
4143
/// - An IntegerAttr (for compile-time constant indices)
4244
/// - Invalid/uninitialized (represented by std::monostate)
45+
/// And a qreg value to represent the qreg that the index belongs to
4346
///
4447
/// The struct uses std::variant to ensure only one type is active at a time,
4548
/// preventing invalid states.
@@ -54,17 +57,21 @@ namespace quantum {
5457
/// Value idx = dynamicIdx.getValue(); // Get the Value
5558
/// }
5659
/// }
57-
struct QubitIndex {
60+
class QubitIndex {
61+
private:
5862
// use monostate to represent the invalid index
5963
std::variant<std::monostate, Value, IntegerAttr> index;
64+
Value qreg;
6065

61-
QubitIndex() : index(std::monostate()) {}
62-
QubitIndex(Value val) : index(val) {}
63-
QubitIndex(IntegerAttr attr) : index(attr) {}
66+
public:
67+
QubitIndex() : index(std::monostate()), qreg(nullptr) {}
68+
QubitIndex(Value val, Value qreg) : index(val), qreg(qreg) {}
69+
QubitIndex(IntegerAttr attr, Value qreg) : index(attr), qreg(qreg) {}
6470

6571
bool isValue() const { return std::holds_alternative<Value>(index); }
6672
bool isAttr() const { return std::holds_alternative<IntegerAttr>(index); }
6773
operator bool() const { return isValue() || isAttr(); }
74+
Value getReg() const { return qreg; }
6875
Value getValue() const { return isValue() ? std::get<Value>(index) : nullptr; }
6976
IntegerAttr getAttr() const { return isAttr() ? std::get<IntegerAttr>(index) : nullptr; }
7077
};
@@ -88,13 +95,6 @@ class OpSignatureAnalyzer {
8895
if (!enableQregMode)
8996
return;
9097

91-
signature.sourceQreg = getSourceQreg(signature.inQubits.front());
92-
if (!signature.sourceQreg) {
93-
op.emitError("Cannot get source qreg");
94-
isValid = false;
95-
return;
96-
}
97-
9898
// input wire indices
9999
for (Value qubit : signature.inQubits) {
100100
const QubitIndex index = getExtractIndex(qubit);
@@ -120,6 +120,20 @@ class OpSignatureAnalyzer {
120120
// Output qubit indices are the same as input qubit indices
121121
signature.outQubitIndices = signature.inWireIndices;
122122
signature.outCtrlQubitIndices = signature.inCtrlWireIndices;
123+
124+
assert((signature.inWireIndices.size() + signature.inCtrlWireIndices.size()) > 0 &&
125+
"inWireIndices or inCtrlWireIndices should not be empty");
126+
127+
// Get the first qreg as reference
128+
Value refQreg = !signature.inWireIndices.empty() ? signature.inWireIndices[0].getReg()
129+
: signature.inCtrlWireIndices[0].getReg();
130+
131+
// Check if any qreg is different
132+
signature.needAllocQreg =
133+
std::any_of(signature.inWireIndices.begin(), signature.inWireIndices.end(),
134+
[refQreg](const QubitIndex &idx) { return idx.getReg() != refQreg; }) ||
135+
std::any_of(signature.inCtrlWireIndices.begin(), signature.inCtrlWireIndices.end(),
136+
[refQreg](const QubitIndex &idx) { return idx.getReg() != refQreg; });
123137
}
124138

125139
operator bool() const { return isValid; }
@@ -144,12 +158,30 @@ class OpSignatureAnalyzer {
144158

145159
int operandIdx = 0;
146160
if (isa<quantum::QuregType>(funcInputs[0])) {
147-
Value updatedQreg = signature.sourceQreg;
161+
// Allocate a new qreg if needed
162+
Value updatedQreg = signature.inWireIndices[0].getReg();
163+
if (signature.needAllocQreg) {
164+
// allocate a new qreg with the number of qubits
165+
auto nqubits = signature.inWireIndices.size() + signature.inCtrlWireIndices.size();
166+
IntegerAttr nqubitsAttr = IntegerAttr::get(rewriter.getI64Type(), nqubits);
167+
auto allocOp = rewriter.create<quantum::AllocOp>(
168+
loc, quantum::QuregType::get(rewriter.getContext()), nullptr, nqubitsAttr);
169+
updatedQreg = allocOp.getQreg();
170+
}
171+
148172
for (auto [i, qubit] : llvm::enumerate(signature.inQubits)) {
149173
const QubitIndex &index = signature.inWireIndices[i];
150-
updatedQreg =
151-
rewriter.create<quantum::InsertOp>(loc, updatedQreg.getType(), updatedQreg,
152-
index.getValue(), index.getAttr(), qubit);
174+
175+
if (signature.needAllocQreg) {
176+
auto attr = IntegerAttr::get(rewriter.getI64Type(), i);
177+
updatedQreg = rewriter.create<quantum::InsertOp>(
178+
loc, updatedQreg.getType(), updatedQreg, nullptr, attr, qubit);
179+
}
180+
else {
181+
updatedQreg = rewriter.create<quantum::InsertOp>(loc, updatedQreg.getType(),
182+
updatedQreg, index.getValue(),
183+
index.getAttr(), qubit);
184+
}
153185
}
154186

155187
operands[operandIdx++] = updatedQreg;
@@ -163,15 +195,32 @@ class OpSignatureAnalyzer {
163195
}
164196
}
165197

166-
if (!signature.inWireIndices.empty()) {
167-
operands[operandIdx] = fromTensorOrAsIs(signature.inWireIndices,
168-
funcInputs[operandIdx], rewriter, loc);
198+
// preprocessing indices
199+
// If needAllocQreg, the indices should be updated to from 0 to nqubits - 1
200+
// instead of the original indices, since we will use the new qreg for the indices
201+
auto wireIndices = signature.inWireIndices;
202+
auto ctrlWireIndices = signature.inCtrlWireIndices;
203+
if (signature.needAllocQreg) {
204+
for (auto [i, index] : llvm::enumerate(wireIndices)) {
205+
auto attr = IntegerAttr::get(rewriter.getI64Type(), i);
206+
wireIndices[i] = QubitIndex(attr, index.getReg());
207+
}
208+
auto inWireIndicesSize = wireIndices.size();
209+
for (auto [i, index] : llvm::enumerate(ctrlWireIndices)) {
210+
auto attr = IntegerAttr::get(rewriter.getI64Type(), i + inWireIndicesSize);
211+
ctrlWireIndices[i] = QubitIndex(attr, index.getReg());
212+
}
213+
}
214+
215+
if (!wireIndices.empty()) {
216+
operands[operandIdx] =
217+
fromTensorOrAsIs(wireIndices, funcInputs[operandIdx], rewriter, loc);
169218
operandIdx++;
170219
}
171220

172-
if (!signature.inCtrlWireIndices.empty()) {
173-
operands[operandIdx] = fromTensorOrAsIs(signature.inCtrlWireIndices,
174-
funcInputs[operandIdx], rewriter, loc);
221+
if (!ctrlWireIndices.empty()) {
222+
operands[operandIdx] =
223+
fromTensorOrAsIs(ctrlWireIndices, funcInputs[operandIdx], rewriter, loc);
175224
operandIdx++;
176225
}
177226
}
@@ -218,18 +267,37 @@ class OpSignatureAnalyzer {
218267

219268
SmallVector<Value> newResults;
220269
rewriter.setInsertionPointAfter(callOp);
221-
for (const QubitIndex &index : signature.outQubitIndices) {
270+
271+
auto outQubitIndices = signature.outQubitIndices;
272+
auto outCtrlQubitIndices = signature.outCtrlQubitIndices;
273+
if (signature.needAllocQreg) {
274+
for (auto [i, index] : llvm::enumerate(outQubitIndices)) {
275+
auto attr = IntegerAttr::get(rewriter.getI64Type(), i);
276+
outQubitIndices[i] = QubitIndex(attr, index.getReg());
277+
}
278+
for (auto [i, index] : llvm::enumerate(outCtrlQubitIndices)) {
279+
auto attr = IntegerAttr::get(rewriter.getI64Type(), i + outQubitIndices.size());
280+
outCtrlQubitIndices[i] = QubitIndex(attr, index.getReg());
281+
}
282+
}
283+
284+
for (const QubitIndex &index : outQubitIndices) {
222285
auto extractOp = rewriter.create<quantum::ExtractOp>(
223286
callOp.getLoc(), rewriter.getType<quantum::QubitType>(), qreg, index.getValue(),
224287
index.getAttr());
225288
newResults.emplace_back(extractOp.getResult());
226289
}
227-
for (const QubitIndex &index : signature.outCtrlQubitIndices) {
290+
for (const QubitIndex &index : outCtrlQubitIndices) {
228291
auto extractOp = rewriter.create<quantum::ExtractOp>(
229292
callOp.getLoc(), rewriter.getType<quantum::QubitType>(), qreg, index.getValue(),
230293
index.getAttr());
231294
newResults.emplace_back(extractOp.getResult());
232295
}
296+
297+
if (signature.needAllocQreg) {
298+
rewriter.create<quantum::DeallocOp>(callOp.getLoc(), qreg);
299+
}
300+
233301
return newResults;
234302
}
235303

@@ -245,11 +313,14 @@ class OpSignatureAnalyzer {
245313
ValueRange outCtrlQubits;
246314

247315
// Qreg mode specific information
248-
Value sourceQreg = nullptr;
249316
SmallVector<QubitIndex> inWireIndices;
250317
SmallVector<QubitIndex> inCtrlWireIndices;
251318
SmallVector<QubitIndex> outQubitIndices;
252319
SmallVector<QubitIndex> outCtrlQubitIndices;
320+
321+
// Qreg mode specific information, if true, a new qreg should be allocated before function
322+
// call and deallocated after function call
323+
bool needAllocQreg = false;
253324
} signature;
254325

255326
Value fromTensorOrAsIs(ValueRange values, Type type, PatternRewriter &rewriter, Location loc)
@@ -356,10 +427,10 @@ class OpSignatureAnalyzer {
356427
while (qubit) {
357428
if (auto extractOp = qubit.getDefiningOp<quantum::ExtractOp>()) {
358429
if (Value idx = extractOp.getIdx()) {
359-
return QubitIndex(idx);
430+
return QubitIndex(idx, extractOp.getQreg());
360431
}
361432
if (IntegerAttr idxAttr = extractOp.getIdxAttrAttr()) {
362-
return QubitIndex(idxAttr);
433+
return QubitIndex(idxAttr, extractOp.getQreg());
363434
}
364435
}
365436

mlir/test/Quantum/DecomposeLoweringTest.mlir

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,92 @@ module @single_hadamard {
8484
}
8585
}
8686

87+
// -----
88+
89+
module @cz_hadamard {
90+
func.func public @test_cz_hadamard() -> tensor<2xf64> attributes {decompose_gatesets = [["CZ", "Hadamard"]]} {
91+
%cst = arith.constant dense<[0, 1]> : tensor<2xi64>
92+
%0 = quantum.alloc( 1) : !quantum.reg
93+
%1 = quantum.alloc( 1) : !quantum.reg
94+
95+
// Extract qubits from different qregs (this will trigger needAllocQreg)
96+
%2 = quantum.extract %1[ 0] : !quantum.reg -> !quantum.bit
97+
%3 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
98+
99+
// CHECK: [[CST:%.+]] = arith.constant dense<[0, 1]> : tensor<2xi64>
100+
// CHECK: [[REG0:%.+]] = quantum.alloc( 1) : !quantum.reg
101+
// CHECK: [[REG1:%.+]] = quantum.alloc( 1) : !quantum.reg
102+
// CHECK: [[QUBIT1:%.+]] = quantum.extract [[REG1]][ 0] : !quantum.reg -> !quantum.bit
103+
// CHECK: [[QUBIT2:%.+]] = quantum.extract [[REG0]][ 0] : !quantum.reg -> !quantum.bit
104+
// CHECK: [[NEW_REG:%.+]] = quantum.alloc( 2) : !quantum.reg
105+
// CHECK: [[INSERT1:%.+]] = quantum.insert [[NEW_REG]][ 0], [[QUBIT1]] : !quantum.reg, !quantum.bit
106+
// CHECK: [[INSERT2:%.+]] = quantum.insert [[INSERT1]][ 1], [[QUBIT2]] : !quantum.reg, !quantum.bit
107+
// CHECK: [[SLICE1:%.+]] = stablehlo.slice [[CST]] [1:2] : (tensor<2xi64>) -> tensor<1xi64>
108+
// CHECK: [[RESHAPE1:%.+]] = stablehlo.reshape [[SLICE1]] : (tensor<1xi64>) -> tensor<i64>
109+
// CHECK: [[EXTRACTED:%.+]] = tensor.extract [[RESHAPE1]][] : tensor<i64>
110+
// CHECK: [[EXTRACT1:%.+]] = quantum.extract [[INSERT2]][[[EXTRACTED]]] : !quantum.reg -> !quantum.bit
111+
// CHECK: [[H1:%.+]] = quantum.custom "Hadamard"() [[EXTRACT1]] : !quantum.bit
112+
// CHECK: [[SLICE2:%.+]] = stablehlo.slice [[CST]] [0:1] : (tensor<2xi64>) -> tensor<1xi64>
113+
// CHECK: [[RESHAPE2:%.+]] = stablehlo.reshape [[SLICE2]] : (tensor<1xi64>) -> tensor<i64>
114+
// CHECK: [[INSERT_H:%.+]] = quantum.insert [[INSERT2]][[[EXTRACTED]]], [[H1]] : !quantum.reg, !quantum.bit
115+
// CHECK: [[EXTRACTED_0:%.+]] = tensor.extract [[RESHAPE2]][] : tensor<i64>
116+
// CHECK: [[EXTRACT2:%.+]] = quantum.extract [[INSERT_H]][[[EXTRACTED_0]]] : !quantum.reg -> !quantum.bit
117+
// CHECK: [[EXTRACT3:%.+]] = quantum.extract [[INSERT_H]][[[EXTRACTED]]] : !quantum.reg -> !quantum.bit
118+
// CHECK: [[CZ_RESULT:%.+]]:2 = quantum.custom "CZ"() [[EXTRACT2]], [[EXTRACT3]] : !quantum.bit, !quantum.bit
119+
// CHECK: [[INSERT_CZ1:%.+]] = quantum.insert [[INSERT_H]][[[EXTRACTED_0]]], [[CZ_RESULT]]#0 : !quantum.reg, !quantum.bit
120+
// CHECK: [[H2:%.+]] = quantum.custom "Hadamard"() [[CZ_RESULT]]#1 : !quantum.bit
121+
// CHECK: [[INSERT_CZ2:%.+]] = quantum.insert [[INSERT_CZ1]][[[EXTRACTED]]], [[H2]] : !quantum.reg, !quantum.bit
122+
// CHECK: [[FINAL_EXTRACT1:%.+]] = quantum.extract [[INSERT_CZ2]][ 0] : !quantum.reg -> !quantum.bit
123+
// CHECK: [[FINAL_EXTRACT2:%.+]] = quantum.extract [[INSERT_CZ2]][ 1] : !quantum.reg -> !quantum.bit
124+
// CHECK: quantum.dealloc [[INSERT_CZ2]] : !quantum.reg
125+
// CHECK-NOT: quantum.custom "CNOT"
126+
%out_qubits:2 = quantum.custom "CNOT"() %2, %3 : !quantum.bit, !quantum.bit
127+
128+
%4 = quantum.insert %1[ 0], %out_qubits#0 : !quantum.reg, !quantum.bit
129+
quantum.dealloc %4 : !quantum.reg
130+
%5 = quantum.compbasis qubits %out_qubits#1 : !quantum.obs
131+
%6 = quantum.probs %5 : tensor<2xf64>
132+
%7 = quantum.insert %0[ 0], %out_qubits#1 : !quantum.reg, !quantum.bit
133+
quantum.dealloc %7 : !quantum.reg
134+
return %6 : tensor<2xf64>
135+
}
136+
137+
// Decomposition function for CNOT gate into CZ and Hadamard
138+
// CHECK-NOT: func.func private @cz_hadamard
139+
func.func private @cz_hadamard(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg attributes {target_gate = "CNOT", llvm.linkage = #llvm.linkage<internal>} {
140+
%0 = stablehlo.slice %arg1 [1:2] : (tensor<2xi64>) -> tensor<1xi64>
141+
%1 = stablehlo.reshape %0 : (tensor<1xi64>) -> tensor<i64>
142+
%extracted = tensor.extract %1[] : tensor<i64>
143+
%2 = quantum.extract %arg0[%extracted] : !quantum.reg -> !quantum.bit
144+
%out_qubits = quantum.custom "Hadamard"() %2 : !quantum.bit
145+
%3 = stablehlo.slice %arg1 [0:1] : (tensor<2xi64>) -> tensor<1xi64>
146+
%4 = stablehlo.reshape %3 : (tensor<1xi64>) -> tensor<i64>
147+
%5 = stablehlo.slice %arg1 [1:2] : (tensor<2xi64>) -> tensor<1xi64>
148+
%6 = stablehlo.reshape %5 : (tensor<1xi64>) -> tensor<i64>
149+
%extracted_0 = tensor.extract %1[] : tensor<i64>
150+
%7 = quantum.insert %arg0[%extracted_0], %out_qubits : !quantum.reg, !quantum.bit
151+
%extracted_1 = tensor.extract %4[] : tensor<i64>
152+
%8 = quantum.extract %7[%extracted_1] : !quantum.reg -> !quantum.bit
153+
%extracted_2 = tensor.extract %6[] : tensor<i64>
154+
%9 = quantum.extract %7[%extracted_2] : !quantum.reg -> !quantum.bit
155+
%out_qubits_3:2 = quantum.custom "CZ"() %8, %9 : !quantum.bit, !quantum.bit
156+
%10 = stablehlo.slice %arg1 [1:2] : (tensor<2xi64>) -> tensor<1xi64>
157+
%11 = stablehlo.reshape %10 : (tensor<1xi64>) -> tensor<i64>
158+
%extracted_4 = tensor.extract %4[] : tensor<i64>
159+
%12 = quantum.insert %7[%extracted_4], %out_qubits_3#0 : !quantum.reg, !quantum.bit
160+
%extracted_5 = tensor.extract %6[] : tensor<i64>
161+
%13 = quantum.insert %12[%extracted_5], %out_qubits_3#1 : !quantum.reg, !quantum.bit
162+
%extracted_6 = tensor.extract %11[] : tensor<i64>
163+
%14 = quantum.extract %13[%extracted_6] : !quantum.reg -> !quantum.bit
164+
%out_qubits_7 = quantum.custom "Hadamard"() %14 : !quantum.bit
165+
%extracted_8 = tensor.extract %11[] : tensor<i64>
166+
%15 = quantum.insert %13[%extracted_8], %out_qubits_7 : !quantum.reg, !quantum.bit
167+
return %15 : !quantum.reg
168+
}
169+
}
170+
171+
172+
87173
// -----
88174
module @recursive {
89175
func.func public @test_recursive() -> tensor<4xf64> {

0 commit comments

Comments
 (0)