From cdef0ffba5278c11269777f21ca4dd43aa5a4d76 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 30 Sep 2025 23:54:32 -0400 Subject: [PATCH 1/4] Add dynamic qreg alloc for cross-qreg gate decomposition --- .../Transforms/DecomposeLoweringPatterns.cpp | 123 ++++++++++++++---- mlir/test/Quantum/DecomposeLoweringTest.mlir | 86 ++++++++++++ 2 files changed, 183 insertions(+), 26 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp index 9dcc4ea1ad..5b0d42c435 100644 --- a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp @@ -12,8 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #define DEBUG_TYPE "decompose-lowering" +#include #include #include "llvm/ADT/StringMap.h" @@ -40,6 +42,7 @@ namespace quantum { /// - A runtime Value (for dynamic indices computed at runtime) /// - An IntegerAttr (for compile-time constant indices) /// - Invalid/uninitialized (represented by std::monostate) +/// And a qreg value to represent the qreg that the index belongs to /// /// The struct uses std::variant to ensure only one type is active at a time, /// preventing invalid states. @@ -54,17 +57,21 @@ namespace quantum { /// Value idx = dynamicIdx.getValue(); // Get the Value /// } /// } -struct QubitIndex { +class QubitIndex { + private: // use monostate to represent the invalid index std::variant index; + Value qreg; - QubitIndex() : index(std::monostate()) {} - QubitIndex(Value val) : index(val) {} - QubitIndex(IntegerAttr attr) : index(attr) {} + public: + QubitIndex() : index(std::monostate()), qreg(nullptr) {} + QubitIndex(Value val, Value qreg) : index(val), qreg(qreg) {} + QubitIndex(IntegerAttr attr, Value qreg) : index(attr), qreg(qreg) {} bool isValue() const { return std::holds_alternative(index); } bool isAttr() const { return std::holds_alternative(index); } operator bool() const { return isValue() || isAttr(); } + Value getReg() const { return qreg; } Value getValue() const { return isValue() ? std::get(index) : nullptr; } IntegerAttr getAttr() const { return isAttr() ? std::get(index) : nullptr; } }; @@ -88,13 +95,6 @@ class OpSignatureAnalyzer { if (!enableQregMode) return; - signature.sourceQreg = getSourceQreg(signature.inQubits.front()); - if (!signature.sourceQreg) { - op.emitError("Cannot get source qreg"); - isValid = false; - return; - } - // input wire indices for (Value qubit : signature.inQubits) { const QubitIndex index = getExtractIndex(qubit); @@ -120,6 +120,20 @@ class OpSignatureAnalyzer { // Output qubit indices are the same as input qubit indices signature.outQubitIndices = signature.inWireIndices; signature.outCtrlQubitIndices = signature.inCtrlWireIndices; + + assert((signature.inWireIndices.size() + signature.inCtrlWireIndices.size()) > 0 && + "inWireIndices or inCtrlWireIndices should not be empty"); + + // Get the first qreg as reference + Value refQreg = !signature.inWireIndices.empty() ? signature.inWireIndices[0].getReg() + : signature.inCtrlWireIndices[0].getReg(); + + // Check if any qreg is different + signature.needAllocQreg = + std::any_of(signature.inWireIndices.begin(), signature.inWireIndices.end(), + [refQreg](const QubitIndex &idx) { return idx.getReg() != refQreg; }) || + std::any_of(signature.inCtrlWireIndices.begin(), signature.inCtrlWireIndices.end(), + [refQreg](const QubitIndex &idx) { return idx.getReg() != refQreg; }); } operator bool() const { return isValid; } @@ -144,12 +158,30 @@ class OpSignatureAnalyzer { int operandIdx = 0; if (isa(funcInputs[0])) { - Value updatedQreg = signature.sourceQreg; + // Allocate a new qreg if needed + Value updatedQreg = signature.inWireIndices[0].getReg(); + if (signature.needAllocQreg) { + // allocate a new qreg with the number of qubits + auto nqubits = signature.inWireIndices.size() + signature.inCtrlWireIndices.size(); + IntegerAttr nqubitsAttr = IntegerAttr::get(rewriter.getI64Type(), nqubits); + auto allocOp = rewriter.create( + loc, quantum::QuregType::get(rewriter.getContext()), nullptr, nqubitsAttr); + updatedQreg = allocOp.getQreg(); + } + for (auto [i, qubit] : llvm::enumerate(signature.inQubits)) { const QubitIndex &index = signature.inWireIndices[i]; - updatedQreg = - rewriter.create(loc, updatedQreg.getType(), updatedQreg, - index.getValue(), index.getAttr(), qubit); + + if (signature.needAllocQreg) { + auto attr = IntegerAttr::get(rewriter.getI64Type(), i); + updatedQreg = rewriter.create( + loc, updatedQreg.getType(), updatedQreg, nullptr, attr, qubit); + } + else { + updatedQreg = rewriter.create(loc, updatedQreg.getType(), + updatedQreg, index.getValue(), + index.getAttr(), qubit); + } } operands[operandIdx++] = updatedQreg; @@ -163,15 +195,32 @@ class OpSignatureAnalyzer { } } - if (!signature.inWireIndices.empty()) { - operands[operandIdx] = fromTensorOrAsIs(signature.inWireIndices, - funcInputs[operandIdx], rewriter, loc); + // preprocessing indices + // If needAllocQreg, the indices should be updated to from 0 to nqubits - 1 + // instead of the original indices, since we will use the new qreg for the indices + auto wireIndices = signature.inWireIndices; + auto ctrlWireIndices = signature.inCtrlWireIndices; + if (signature.needAllocQreg) { + for (auto [i, index] : llvm::enumerate(wireIndices)) { + auto attr = IntegerAttr::get(rewriter.getI64Type(), i); + wireIndices[i] = QubitIndex(attr, index.getReg()); + } + auto inWireIndicesSize = wireIndices.size(); + for (auto [i, index] : llvm::enumerate(ctrlWireIndices)) { + auto attr = IntegerAttr::get(rewriter.getI64Type(), i + inWireIndicesSize); + ctrlWireIndices[i] = QubitIndex(attr, index.getReg()); + } + } + + if (!wireIndices.empty()) { + operands[operandIdx] = + fromTensorOrAsIs(wireIndices, funcInputs[operandIdx], rewriter, loc); operandIdx++; } - if (!signature.inCtrlWireIndices.empty()) { - operands[operandIdx] = fromTensorOrAsIs(signature.inCtrlWireIndices, - funcInputs[operandIdx], rewriter, loc); + if (!ctrlWireIndices.empty()) { + operands[operandIdx] = + fromTensorOrAsIs(ctrlWireIndices, funcInputs[operandIdx], rewriter, loc); operandIdx++; } } @@ -218,18 +267,37 @@ class OpSignatureAnalyzer { SmallVector newResults; rewriter.setInsertionPointAfter(callOp); - for (const QubitIndex &index : signature.outQubitIndices) { + + auto outQubitIndices = signature.outQubitIndices; + auto outCtrlQubitIndices = signature.outCtrlQubitIndices; + if (signature.needAllocQreg) { + for (auto [i, index] : llvm::enumerate(outQubitIndices)) { + auto attr = IntegerAttr::get(rewriter.getI64Type(), i); + outQubitIndices[i] = QubitIndex(attr, index.getReg()); + } + for (auto [i, index] : llvm::enumerate(outCtrlQubitIndices)) { + auto attr = IntegerAttr::get(rewriter.getI64Type(), i + outQubitIndices.size()); + outCtrlQubitIndices[i] = QubitIndex(attr, index.getReg()); + } + } + + for (const QubitIndex &index : outQubitIndices) { auto extractOp = rewriter.create( callOp.getLoc(), rewriter.getType(), qreg, index.getValue(), index.getAttr()); newResults.emplace_back(extractOp.getResult()); } - for (const QubitIndex &index : signature.outCtrlQubitIndices) { + for (const QubitIndex &index : outCtrlQubitIndices) { auto extractOp = rewriter.create( callOp.getLoc(), rewriter.getType(), qreg, index.getValue(), index.getAttr()); newResults.emplace_back(extractOp.getResult()); } + + if (signature.needAllocQreg) { + rewriter.create(callOp.getLoc(), qreg); + } + return newResults; } @@ -245,11 +313,14 @@ class OpSignatureAnalyzer { ValueRange outCtrlQubits; // Qreg mode specific information - Value sourceQreg = nullptr; SmallVector inWireIndices; SmallVector inCtrlWireIndices; SmallVector outQubitIndices; SmallVector outCtrlQubitIndices; + + // Qreg mode specific information, if true, a new qreg should be allocated before function + // call and deallocated after function call + bool needAllocQreg = false; } signature; Value fromTensorOrAsIs(ValueRange values, Type type, PatternRewriter &rewriter, Location loc) @@ -356,10 +427,10 @@ class OpSignatureAnalyzer { while (qubit) { if (auto extractOp = qubit.getDefiningOp()) { if (Value idx = extractOp.getIdx()) { - return QubitIndex(idx); + return QubitIndex(idx, extractOp.getQreg()); } if (IntegerAttr idxAttr = extractOp.getIdxAttrAttr()) { - return QubitIndex(idxAttr); + return QubitIndex(idxAttr, extractOp.getQreg()); } } diff --git a/mlir/test/Quantum/DecomposeLoweringTest.mlir b/mlir/test/Quantum/DecomposeLoweringTest.mlir index 91bfbe7778..e22c2e7df4 100644 --- a/mlir/test/Quantum/DecomposeLoweringTest.mlir +++ b/mlir/test/Quantum/DecomposeLoweringTest.mlir @@ -84,6 +84,92 @@ module @single_hadamard { } } +// ----- + +module @cz_hadamard { + func.func public @test_cz_hadamard() -> tensor<2xf64> attributes {decompose_gatesets = [["CZ", "Hadamard"]]} { + %cst = arith.constant dense<[0, 1]> : tensor<2xi64> + %0 = quantum.alloc( 1) : !quantum.reg + %1 = quantum.alloc( 1) : !quantum.reg + + // Extract qubits from different qregs (this will trigger needAllocQreg) + %2 = quantum.extract %1[ 0] : !quantum.reg -> !quantum.bit + %3 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + + // CHECK: [[CST:%.+]] = arith.constant dense<[0, 1]> : tensor<2xi64> + // CHECK: [[REG0:%.+]] = quantum.alloc( 1) : !quantum.reg + // CHECK: [[REG1:%.+]] = quantum.alloc( 1) : !quantum.reg + // CHECK: [[QUBIT1:%.+]] = quantum.extract [[REG1]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[QUBIT2:%.+]] = quantum.extract [[REG0]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[NEW_REG:%.+]] = quantum.alloc( 2) : !quantum.reg + // CHECK: [[INSERT1:%.+]] = quantum.insert [[NEW_REG]][ 0], [[QUBIT1]] : !quantum.reg, !quantum.bit + // CHECK: [[INSERT2:%.+]] = quantum.insert [[INSERT1]][ 1], [[QUBIT2]] : !quantum.reg, !quantum.bit + // CHECK: [[SLICE1:%.+]] = stablehlo.slice [[CST]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> + // CHECK: [[RESHAPE1:%.+]] = stablehlo.reshape [[SLICE1]] : (tensor<1xi64>) -> tensor + // CHECK: [[EXTRACTED:%.+]] = tensor.extract [[RESHAPE1]][] : tensor + // CHECK: [[EXTRACT1:%.+]] = quantum.extract [[INSERT2]][[[EXTRACTED]]] : !quantum.reg -> !quantum.bit + // CHECK: [[H1:%.+]] = quantum.custom "Hadamard"() [[EXTRACT1]] : !quantum.bit + // CHECK: [[SLICE2:%.+]] = stablehlo.slice [[CST]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> + // CHECK: [[RESHAPE2:%.+]] = stablehlo.reshape [[SLICE2]] : (tensor<1xi64>) -> tensor + // CHECK: [[INSERT_H:%.+]] = quantum.insert [[INSERT2]][[[EXTRACTED]]], [[H1]] : !quantum.reg, !quantum.bit + // CHECK: [[EXTRACTED_0:%.+]] = tensor.extract [[RESHAPE2]][] : tensor + // CHECK: [[EXTRACT2:%.+]] = quantum.extract [[INSERT_H]][[[EXTRACTED_0]]] : !quantum.reg -> !quantum.bit + // CHECK: [[EXTRACT3:%.+]] = quantum.extract [[INSERT_H]][[[EXTRACTED]]] : !quantum.reg -> !quantum.bit + // CHECK: [[CZ_RESULT:%.+]]:2 = quantum.custom "CZ"() [[EXTRACT2]], [[EXTRACT3]] : !quantum.bit, !quantum.bit + // CHECK: [[INSERT_CZ1:%.+]] = quantum.insert [[INSERT_H]][[[EXTRACTED_0]]], [[CZ_RESULT]]#0 : !quantum.reg, !quantum.bit + // CHECK: [[H2:%.+]] = quantum.custom "Hadamard"() [[CZ_RESULT]]#1 : !quantum.bit + // CHECK: [[INSERT_CZ2:%.+]] = quantum.insert [[INSERT_CZ1]][[[EXTRACTED]]], [[H2]] : !quantum.reg, !quantum.bit + // CHECK: [[FINAL_EXTRACT1:%.+]] = quantum.extract [[INSERT_CZ2]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[FINAL_EXTRACT2:%.+]] = quantum.extract [[INSERT_CZ2]][ 1] : !quantum.reg -> !quantum.bit + // CHECK: quantum.dealloc [[INSERT_CZ2]] : !quantum.reg + // CHECK-NOT: quantum.custom "CNOT" + %out_qubits:2 = quantum.custom "CNOT"() %2, %3 : !quantum.bit, !quantum.bit + + %4 = quantum.insert %1[ 0], %out_qubits#0 : !quantum.reg, !quantum.bit + quantum.dealloc %4 : !quantum.reg + %5 = quantum.compbasis qubits %out_qubits#1 : !quantum.obs + %6 = quantum.probs %5 : tensor<2xf64> + %7 = quantum.insert %0[ 0], %out_qubits#1 : !quantum.reg, !quantum.bit + quantum.dealloc %7 : !quantum.reg + return %6 : tensor<2xf64> + } + + // Decomposition function for CNOT gate into CZ and Hadamard + // CHECK-NOT: func.func private @cz_hadamard + func.func private @cz_hadamard(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg attributes {target_gate = "CNOT", llvm.linkage = #llvm.linkage} { + %0 = stablehlo.slice %arg1 [1:2] : (tensor<2xi64>) -> tensor<1xi64> + %1 = stablehlo.reshape %0 : (tensor<1xi64>) -> tensor + %extracted = tensor.extract %1[] : tensor + %2 = quantum.extract %arg0[%extracted] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "Hadamard"() %2 : !quantum.bit + %3 = stablehlo.slice %arg1 [0:1] : (tensor<2xi64>) -> tensor<1xi64> + %4 = stablehlo.reshape %3 : (tensor<1xi64>) -> tensor + %5 = stablehlo.slice %arg1 [1:2] : (tensor<2xi64>) -> tensor<1xi64> + %6 = stablehlo.reshape %5 : (tensor<1xi64>) -> tensor + %extracted_0 = tensor.extract %1[] : tensor + %7 = quantum.insert %arg0[%extracted_0], %out_qubits : !quantum.reg, !quantum.bit + %extracted_1 = tensor.extract %4[] : tensor + %8 = quantum.extract %7[%extracted_1] : !quantum.reg -> !quantum.bit + %extracted_2 = tensor.extract %6[] : tensor + %9 = quantum.extract %7[%extracted_2] : !quantum.reg -> !quantum.bit + %out_qubits_3:2 = quantum.custom "CZ"() %8, %9 : !quantum.bit, !quantum.bit + %10 = stablehlo.slice %arg1 [1:2] : (tensor<2xi64>) -> tensor<1xi64> + %11 = stablehlo.reshape %10 : (tensor<1xi64>) -> tensor + %extracted_4 = tensor.extract %4[] : tensor + %12 = quantum.insert %7[%extracted_4], %out_qubits_3#0 : !quantum.reg, !quantum.bit + %extracted_5 = tensor.extract %6[] : tensor + %13 = quantum.insert %12[%extracted_5], %out_qubits_3#1 : !quantum.reg, !quantum.bit + %extracted_6 = tensor.extract %11[] : tensor + %14 = quantum.extract %13[%extracted_6] : !quantum.reg -> !quantum.bit + %out_qubits_7 = quantum.custom "Hadamard"() %14 : !quantum.bit + %extracted_8 = tensor.extract %11[] : tensor + %15 = quantum.insert %13[%extracted_8], %out_qubits_7 : !quantum.reg, !quantum.bit + return %15 : !quantum.reg + } +} + + + // ----- module @recursive { func.func public @test_recursive() -> tensor<4xf64> { From 0f1493331b56737706217233ac1d91b3dbe3c267 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Wed, 1 Oct 2025 00:11:29 -0400 Subject: [PATCH 2/4] too-complex-method --- .../Transforms/DecomposeLoweringPatterns.cpp | 36 +++++++++++-------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp index 5b0d42c435..20aafef942 100644 --- a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp @@ -138,6 +138,20 @@ class OpSignatureAnalyzer { operator bool() const { return isValid; } + Value getUpdatedQreg(PatternRewriter &rewriter, Location loc) + { + Value updatedQreg = signature.inWireIndices[0].getReg(); + if (signature.needAllocQreg) { + // allocate a new qreg with the number of qubits + auto nqubits = signature.inWireIndices.size() + signature.inCtrlWireIndices.size(); + IntegerAttr nqubitsAttr = IntegerAttr::get(rewriter.getI64Type(), nqubits); + auto allocOp = rewriter.create( + loc, quantum::QuregType::get(rewriter.getContext()), nullptr, nqubitsAttr); + updatedQreg = allocOp.getQreg(); + } + return updatedQreg; + } + // Prepare the operands for calling the decomposition function // There are two cases: // 1. The first input is a qreg, which means the decomposition function is a qreg mode function @@ -159,15 +173,7 @@ class OpSignatureAnalyzer { int operandIdx = 0; if (isa(funcInputs[0])) { // Allocate a new qreg if needed - Value updatedQreg = signature.inWireIndices[0].getReg(); - if (signature.needAllocQreg) { - // allocate a new qreg with the number of qubits - auto nqubits = signature.inWireIndices.size() + signature.inCtrlWireIndices.size(); - IntegerAttr nqubitsAttr = IntegerAttr::get(rewriter.getI64Type(), nqubits); - auto allocOp = rewriter.create( - loc, quantum::QuregType::get(rewriter.getContext()), nullptr, nqubitsAttr); - updatedQreg = allocOp.getQreg(); - } + Value updatedQreg = getUpdatedQreg(rewriter, loc); for (auto [i, qubit] : llvm::enumerate(signature.inQubits)) { const QubitIndex &index = signature.inWireIndices[i]; @@ -198,23 +204,23 @@ class OpSignatureAnalyzer { // preprocessing indices // If needAllocQreg, the indices should be updated to from 0 to nqubits - 1 // instead of the original indices, since we will use the new qreg for the indices - auto wireIndices = signature.inWireIndices; + auto inWireIndices = signature.inWireIndices; auto ctrlWireIndices = signature.inCtrlWireIndices; if (signature.needAllocQreg) { - for (auto [i, index] : llvm::enumerate(wireIndices)) { + for (auto [i, index] : llvm::enumerate(inWireIndices)) { auto attr = IntegerAttr::get(rewriter.getI64Type(), i); - wireIndices[i] = QubitIndex(attr, index.getReg()); + inWireIndices[i] = QubitIndex(attr, index.getReg()); } - auto inWireIndicesSize = wireIndices.size(); + auto inWireIndicesSize = inWireIndices.size(); for (auto [i, index] : llvm::enumerate(ctrlWireIndices)) { auto attr = IntegerAttr::get(rewriter.getI64Type(), i + inWireIndicesSize); ctrlWireIndices[i] = QubitIndex(attr, index.getReg()); } } - if (!wireIndices.empty()) { + if (!inWireIndices.empty()) { operands[operandIdx] = - fromTensorOrAsIs(wireIndices, funcInputs[operandIdx], rewriter, loc); + fromTensorOrAsIs(inWireIndices, funcInputs[operandIdx], rewriter, loc); operandIdx++; } From 38e22d02d797b8c01772eb1498138b68d49b7d87 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Wed, 1 Oct 2025 00:44:46 -0400 Subject: [PATCH 3/4] fix --- .../Transforms/DecomposeLoweringPatterns.cpp | 105 ++++++++---------- 1 file changed, 45 insertions(+), 60 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp index 20aafef942..a645289faa 100644 --- a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp @@ -12,10 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #define DEBUG_TYPE "decompose-lowering" -#include #include #include "llvm/ADT/StringMap.h" @@ -82,7 +80,7 @@ class QubitIndex { class OpSignatureAnalyzer { public: OpSignatureAnalyzer() = delete; - OpSignatureAnalyzer(CustomOp op, bool enableQregMode) + OpSignatureAnalyzer(CustomOp op, bool enableQregMode, PatternRewriter &rewriter) : signature(OpSignature{ .params = op.getParams(), .inQubits = op.getInQubits(), @@ -90,6 +88,7 @@ class OpSignatureAnalyzer { .inCtrlValues = op.getInCtrlValues(), .outQubits = op.getOutQubits(), .outCtrlQubits = op.getOutCtrlQubits(), + .rewriter = rewriter, }) { if (!enableQregMode) @@ -117,10 +116,6 @@ class OpSignatureAnalyzer { signature.inCtrlWireIndices.emplace_back(index); } - // Output qubit indices are the same as input qubit indices - signature.outQubitIndices = signature.inWireIndices; - signature.outCtrlQubitIndices = signature.inCtrlWireIndices; - assert((signature.inWireIndices.size() + signature.inCtrlWireIndices.size()) > 0 && "inWireIndices or inCtrlWireIndices should not be empty"); @@ -134,22 +129,39 @@ class OpSignatureAnalyzer { [refQreg](const QubitIndex &idx) { return idx.getReg() != refQreg; }) || std::any_of(signature.inCtrlWireIndices.begin(), signature.inCtrlWireIndices.end(), [refQreg](const QubitIndex &idx) { return idx.getReg() != refQreg; }); + + // If needAllocQreg, the indices should be updated to from 0 to nqubits - 1 + // Since we will use the new qreg for the indices + if (signature.needAllocQreg) { + for (auto [i, index] : llvm::enumerate(signature.inWireIndices)) { + auto attr = IntegerAttr::get(rewriter.getI64Type(), i); + signature.inWireIndices[i] = QubitIndex(attr, index.getReg()); + } + for (auto [i, index] : llvm::enumerate(signature.inCtrlWireIndices)) { + auto attr = + IntegerAttr::get(rewriter.getI64Type(), i + signature.inWireIndices.size()); + signature.inCtrlWireIndices[i] = QubitIndex(attr, index.getReg()); + } + } + + // Output qubit indices are the same as input qubit indices + signature.outQubitIndices = signature.inWireIndices; + signature.outCtrlQubitIndices = signature.inCtrlWireIndices; } operator bool() const { return isValid; } Value getUpdatedQreg(PatternRewriter &rewriter, Location loc) { - Value updatedQreg = signature.inWireIndices[0].getReg(); if (signature.needAllocQreg) { // allocate a new qreg with the number of qubits auto nqubits = signature.inWireIndices.size() + signature.inCtrlWireIndices.size(); IntegerAttr nqubitsAttr = IntegerAttr::get(rewriter.getI64Type(), nqubits); auto allocOp = rewriter.create( loc, quantum::QuregType::get(rewriter.getContext()), nullptr, nqubitsAttr); - updatedQreg = allocOp.getQreg(); + return allocOp.getQreg(); } - return updatedQreg; + return signature.inWireIndices[0].getReg(); } // Prepare the operands for calling the decomposition function @@ -177,17 +189,16 @@ class OpSignatureAnalyzer { for (auto [i, qubit] : llvm::enumerate(signature.inQubits)) { const QubitIndex &index = signature.inWireIndices[i]; + updatedQreg = + rewriter.create(loc, updatedQreg.getType(), updatedQreg, + index.getValue(), index.getAttr(), qubit); + } - if (signature.needAllocQreg) { - auto attr = IntegerAttr::get(rewriter.getI64Type(), i); - updatedQreg = rewriter.create( - loc, updatedQreg.getType(), updatedQreg, nullptr, attr, qubit); - } - else { - updatedQreg = rewriter.create(loc, updatedQreg.getType(), - updatedQreg, index.getValue(), - index.getAttr(), qubit); - } + for (auto [i, qubit] : llvm::enumerate(signature.inCtrlQubits)) { + const QubitIndex &index = signature.inCtrlWireIndices[i]; + updatedQreg = + rewriter.create(loc, updatedQreg.getType(), updatedQreg, + index.getValue(), index.getAttr(), qubit); } operands[operandIdx++] = updatedQreg; @@ -201,32 +212,15 @@ class OpSignatureAnalyzer { } } - // preprocessing indices - // If needAllocQreg, the indices should be updated to from 0 to nqubits - 1 - // instead of the original indices, since we will use the new qreg for the indices - auto inWireIndices = signature.inWireIndices; - auto ctrlWireIndices = signature.inCtrlWireIndices; - if (signature.needAllocQreg) { - for (auto [i, index] : llvm::enumerate(inWireIndices)) { - auto attr = IntegerAttr::get(rewriter.getI64Type(), i); - inWireIndices[i] = QubitIndex(attr, index.getReg()); - } - auto inWireIndicesSize = inWireIndices.size(); - for (auto [i, index] : llvm::enumerate(ctrlWireIndices)) { - auto attr = IntegerAttr::get(rewriter.getI64Type(), i + inWireIndicesSize); - ctrlWireIndices[i] = QubitIndex(attr, index.getReg()); - } - } - - if (!inWireIndices.empty()) { - operands[operandIdx] = - fromTensorOrAsIs(inWireIndices, funcInputs[operandIdx], rewriter, loc); + if (!signature.inWireIndices.empty()) { + operands[operandIdx] = fromTensorOrAsIs(signature.inWireIndices, + funcInputs[operandIdx], rewriter, loc); operandIdx++; } - if (!ctrlWireIndices.empty()) { - operands[operandIdx] = - fromTensorOrAsIs(ctrlWireIndices, funcInputs[operandIdx], rewriter, loc); + if (!signature.inCtrlWireIndices.empty()) { + operands[operandIdx] = fromTensorOrAsIs(signature.inCtrlWireIndices, + funcInputs[operandIdx], rewriter, loc); operandIdx++; } } @@ -274,26 +268,13 @@ class OpSignatureAnalyzer { SmallVector newResults; rewriter.setInsertionPointAfter(callOp); - auto outQubitIndices = signature.outQubitIndices; - auto outCtrlQubitIndices = signature.outCtrlQubitIndices; - if (signature.needAllocQreg) { - for (auto [i, index] : llvm::enumerate(outQubitIndices)) { - auto attr = IntegerAttr::get(rewriter.getI64Type(), i); - outQubitIndices[i] = QubitIndex(attr, index.getReg()); - } - for (auto [i, index] : llvm::enumerate(outCtrlQubitIndices)) { - auto attr = IntegerAttr::get(rewriter.getI64Type(), i + outQubitIndices.size()); - outCtrlQubitIndices[i] = QubitIndex(attr, index.getReg()); - } - } - - for (const QubitIndex &index : outQubitIndices) { + for (const QubitIndex &index : signature.outQubitIndices) { auto extractOp = rewriter.create( callOp.getLoc(), rewriter.getType(), qreg, index.getValue(), index.getAttr()); newResults.emplace_back(extractOp.getResult()); } - for (const QubitIndex &index : outCtrlQubitIndices) { + for (const QubitIndex &index : signature.outCtrlQubitIndices) { auto extractOp = rewriter.create( callOp.getLoc(), rewriter.getType(), qreg, index.getValue(), index.getAttr()); @@ -327,6 +308,9 @@ class OpSignatureAnalyzer { // Qreg mode specific information, if true, a new qreg should be allocated before function // call and deallocated after function call bool needAllocQreg = false; + + // Rewriter + PatternRewriter &rewriter; } signature; Value fromTensorOrAsIs(ValueRange values, Type type, PatternRewriter &rewriter, Location loc) @@ -499,10 +483,11 @@ struct DecomposeLoweringRewritePattern : public OpRewritePattern { "Decomposition function must have at least one result"); auto enableQreg = isa(decompFunc.getFunctionType().getInput(0)); - auto analyzer = OpSignatureAnalyzer(op, enableQreg); - assert(analyzer && "Analyzer should be valid"); rewriter.setInsertionPointAfter(op); + auto analyzer = OpSignatureAnalyzer(op, enableQreg, rewriter); + assert(analyzer && "Analyzer should be valid"); + auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc()); auto callOp = rewriter.create(op.getLoc(), decompFunc.getFunctionType().getResults(), From 2e5d3505c9e910a83a9700aabdde8c8697d9b0e9 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Wed, 1 Oct 2025 01:11:09 -0400 Subject: [PATCH 4/4] fix --- mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp | 7 ++++--- mlir/test/Quantum/DecomposeLoweringTest.mlir | 1 - 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp index a645289faa..7604780a2d 100644 --- a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp @@ -281,9 +281,10 @@ class OpSignatureAnalyzer { newResults.emplace_back(extractOp.getResult()); } - if (signature.needAllocQreg) { - rewriter.create(callOp.getLoc(), qreg); - } + // FIXME: Dealloc should be fine, but it will cause the error in lightning now + // if (signature.needAllocQreg) { + // rewriter.create(callOp.getLoc(), qreg); + // } return newResults; } diff --git a/mlir/test/Quantum/DecomposeLoweringTest.mlir b/mlir/test/Quantum/DecomposeLoweringTest.mlir index e22c2e7df4..b616ec2532 100644 --- a/mlir/test/Quantum/DecomposeLoweringTest.mlir +++ b/mlir/test/Quantum/DecomposeLoweringTest.mlir @@ -121,7 +121,6 @@ module @cz_hadamard { // CHECK: [[INSERT_CZ2:%.+]] = quantum.insert [[INSERT_CZ1]][[[EXTRACTED]]], [[H2]] : !quantum.reg, !quantum.bit // CHECK: [[FINAL_EXTRACT1:%.+]] = quantum.extract [[INSERT_CZ2]][ 0] : !quantum.reg -> !quantum.bit // CHECK: [[FINAL_EXTRACT2:%.+]] = quantum.extract [[INSERT_CZ2]][ 1] : !quantum.reg -> !quantum.bit - // CHECK: quantum.dealloc [[INSERT_CZ2]] : !quantum.reg // CHECK-NOT: quantum.custom "CNOT" %out_qubits:2 = quantum.custom "CNOT"() %2, %3 : !quantum.bit, !quantum.bit