Skip to content

Commit f3cbe5b

Browse files
committed
♻️ Final improvements: type checks, parameter ordering, pipeline test
1 parent 36b19fe commit f3cbe5b

File tree

5 files changed

+232
-176
lines changed

5 files changed

+232
-176
lines changed

mlir/lib/Dialect/QCO/Transforms/CMakeLists.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
1010
set(LIBRARIES ${dialect_libs} MQT::CoreIR)
1111
add_compile_options(-fexceptions)
1212

13-
message(STATUS "MLIR_DIALECT_LIBS contains: ${dialect_libs}")
14-
1513
file(GLOB_RECURSE TRANSFORMS_SOURCES *.cpp)
1614

1715
add_mlir_library(MLIRQCOTransforms ${TRANSFORMS_SOURCES} LINK_LIBS ${LIBRARIES} DEPENDS

mlir/lib/Dialect/QCO/Transforms/QuaternionMergeRotationGates.cpp

Lines changed: 67 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,10 @@
1111
#include "mlir/Dialect/QCO/IR/QCODialect.h"
1212
#include "mlir/Dialect/QCO/Transforms/Passes.h"
1313

14-
#include <algorithm>
15-
#include <array>
1614
#include <cassert>
1715
#include <cmath>
1816
#include <cstdint>
17+
#include <llvm/ADT/TypeSwitch.h>
1918
#include <llvm/Support/ErrorHandling.h>
2019
#include <mlir/Dialect/Arith/IR/Arith.h>
2120
#include <mlir/Dialect/Math/IR/Math.h>
@@ -26,7 +25,7 @@
2625
#include <mlir/Support/LLVM.h>
2726
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
2827
#include <numbers>
29-
#include <string_view>
28+
#include <optional>
3029
#include <utility>
3130

3231
namespace mlir::qco {
@@ -52,17 +51,14 @@ struct MergeRotationGatesPattern final
5251

5352
enum class RotationAxis : std::uint8_t { X, Y, Z };
5453

55-
static constexpr std::array<std::string_view, 4> MERGEABLE_GATES = {
56-
"u", "rx", "ry", "rz"};
57-
5854
/**
5955
* @brief Checks if an operation is a mergeable rotation gate (rx, ry, rz, u).
6056
*
61-
* @param name Name of the operation to check
57+
* @param op The operation to check
6258
* @return True if mergeable, false otherwise
6359
*/
64-
static bool isMergeable(std::string_view name) {
65-
return std::ranges::find(MERGEABLE_GATES, name) != MERGEABLE_GATES.end();
60+
static bool isMergeable(mlir::Operation* op) {
61+
return mlir::isa<RXOp, RYOp, RZOp, UOp>(op);
6662
}
6763

6864
/**
@@ -78,13 +74,28 @@ struct MergeRotationGatesPattern final
7874
*/
7975
[[nodiscard]] static bool areQuaternionMergeable(mlir::Operation& a,
8076
mlir::Operation& b) {
81-
const auto aName = a.getName().stripDialect().str();
82-
const auto bName = b.getName().stripDialect().str();
83-
84-
if (!(isMergeable(aName) && isMergeable(bName))) {
77+
if (!isMergeable(&a) || !isMergeable(&b)) {
8578
return false;
8679
}
87-
return (aName != bName) || (aName == "u" && bName == "u");
80+
81+
// Different gate types OR both are U gates
82+
return (a.getName() != b.getName()) ||
83+
(mlir::isa<UOp>(a) && mlir::isa<UOp>(b));
84+
}
85+
86+
/**
87+
* @brief Returns the rotation axis for a single-axis rotation gate.
88+
*
89+
* @param op The operation to query
90+
* @return The rotation axis, or std::nullopt if the operation is not a
91+
* single-axis rotation gate (RX, RY, RZ)
92+
*/
93+
static std::optional<RotationAxis> getRotationAxis(mlir::Operation* op) {
94+
return llvm::TypeSwitch<mlir::Operation*, std::optional<RotationAxis>>(op)
95+
.Case<RXOp>([](auto) { return RotationAxis::X; })
96+
.Case<RYOp>([](auto) { return RotationAxis::Y; })
97+
.Case<RZOp>([](auto) { return RotationAxis::Z; })
98+
.Default([](auto) { return std::nullopt; });
8899
}
89100

90101
/**
@@ -130,6 +141,8 @@ struct MergeRotationGatesPattern final
130141
case RotationAxis::Z:
131142
return {.w = cos, .x = zero, .y = zero, .z = sin};
132143
} // NOLINT(bugprone-branch-clone): false positive, branches differ
144+
145+
llvm_unreachable("Invalid rotation axis");
133146
}
134147

135148
/**
@@ -142,24 +155,15 @@ struct MergeRotationGatesPattern final
142155
*/
143156
static Quaternion quaternionFromRotation(UnitaryOpInterface op,
144157
mlir::PatternRewriter& rewriter) {
145-
auto const type = op->getName().stripDialect().str();
146-
147-
if (type == "u") {
158+
if (mlir::isa<UOp>(op)) {
148159
return quaternionFromUGate(op, rewriter);
149160
}
150161

151-
auto loc = op->getLoc();
152-
auto angle = op.getParameter(0);
153-
154-
if (type == "rx") {
155-
return createAxisQuaternion(angle, RotationAxis::X, loc, rewriter);
156-
}
157-
if (type == "ry") {
158-
return createAxisQuaternion(angle, RotationAxis::Y, loc, rewriter);
159-
}
160-
if (type == "rz") {
161-
return createAxisQuaternion(angle, RotationAxis::Z, loc, rewriter);
162+
if (auto axis = getRotationAxis(op.getOperation())) {
163+
return createAxisQuaternion(op.getParameter(0), *axis, op->getLoc(),
164+
rewriter);
162165
}
166+
163167
llvm_unreachable("Unsupported operation type");
164168
}
165169

@@ -227,14 +231,14 @@ struct MergeRotationGatesPattern final
227231
/**
228232
* @brief Converts a u-gate to quaternion representation.
229233
*
230-
* U(alpha, beta, gamma) uses ZYZ decomposition: RZ(alpha) -> RY(beta) ->
231-
* RZ(gamma).
234+
* U(theta, phi, lambda) uses ZYZ decomposition: RZ(lambda) -> RY(theta) ->
235+
* RZ(phi).
232236
*
233237
* When composing rotations, quaternion multiplication follows matrix
234238
* multiplication order (right-to-left), which is the reverse of the
235239
* application sequence:
236-
* Sequential application: RZ(alpha), then RY(beta), then RZ(gamma)
237-
* Quaternion product: Qgamma * Qbeta * Qalpha
240+
* Sequential application: RZ(lambda), then RY(theta), then RZ(phi)
241+
* Quaternion product: qPhi * qTheta * qLambda
238242
*
239243
* @param op The u-gate operation to convert
240244
* @param rewriter Pattern rewriter for creating new operations
@@ -245,26 +249,27 @@ struct MergeRotationGatesPattern final
245249
auto loc = op->getLoc();
246250

247251
// U gate uses ZYZ decomposition:
248-
// U(alpha, beta, gamma) = Rz(alpha) -> Ry(beta) -> Rz(gamma)
249-
auto qAlpha = createAxisQuaternion(op.getParameter(0), RotationAxis::Z, loc,
250-
rewriter);
251-
auto qBeta = createAxisQuaternion(op.getParameter(1), RotationAxis::Y, loc,
252-
rewriter);
253-
auto qGamma = createAxisQuaternion(op.getParameter(2), RotationAxis::Z, loc,
252+
// U(theta, phi, lambda) uses ZYZ decomposition: RZ(lambda) -> RY(theta) ->
253+
// RZ(phi)
254+
auto qTheta = createAxisQuaternion(op.getParameter(0), RotationAxis::Y, loc,
254255
rewriter);
255-
256-
// qGamma * qBeta * qAlpha (multiplication in reverse order!)
257-
auto temp = hamiltonProduct(qGamma, qBeta, op, rewriter);
258-
return hamiltonProduct(temp, qAlpha, op, rewriter);
256+
auto qPhi = createAxisQuaternion(op.getParameter(1), RotationAxis::Z, loc,
257+
rewriter);
258+
auto qLambda = createAxisQuaternion(op.getParameter(2), RotationAxis::Z,
259+
loc, rewriter);
260+
261+
// qPhi * qTheta * qLambda (multiplication in reverse order!)
262+
auto temp = hamiltonProduct(qPhi, qTheta, op, rewriter);
263+
return hamiltonProduct(temp, qLambda, op, rewriter);
259264
}
260265

261266
/**
262267
* @brief Converts a quaternion to a u-gate using ZYZ Euler angle extraction.
263268
*
264269
* For unit quaternion q = w + x*i + y*j + z*k, extracts u-gate parameters:
265-
* alpha = atan2(z, w) - atan2(-x, y)
270+
* alpha = atan2(z, w) + atan2(-x, y)
266271
* beta = acos(2 * (w^2 + z^2) - 1)
267-
* gamma = atan2(z, w) + atan2(-x, y)
272+
* gamma = atan2(z, w) - atan2(-x, y)
268273
*
269274
* Based on Bernardes & Viollet (2022), simplified for unit quaternions and
270275
* proper ZYZ Euler angles (Chapter 3.3):
@@ -287,6 +292,9 @@ struct MergeRotationGatesPattern final
287292
auto loc = op->getLoc();
288293

289294
auto floatType = op.getParameter(0).getType();
295+
// constant -1.0
296+
auto negOneAttr = rewriter.getFloatAttr(floatType, -1.0);
297+
auto negOne = rewriter.create<mlir::arith::ConstantOp>(loc, negOneAttr);
290298
// constant 0.0
291299
auto zeroAttr = rewriter.getFloatAttr(floatType, 0.0);
292300
auto zero = rewriter.create<mlir::arith::ConstantOp>(loc, zeroAttr);
@@ -305,12 +313,18 @@ struct MergeRotationGatesPattern final
305313

306314
// calculate angle beta (for y-rotation)
307315
// beta = acos(2 * (w^2 + z^2) - 1)
316+
// NOTE: the term (2 * (w^2 + z^2) - 1) is clamped to [-1, 1],
317+
// otherwise acos could produce NaN.
308318
auto ww = rewriter.create<mlir::arith::MulFOp>(loc, q.w, q.w);
309319
auto zz = rewriter.create<mlir::arith::MulFOp>(loc, q.z, q.z);
310320
auto bTemp1 = rewriter.create<mlir::arith::AddFOp>(loc, ww, zz);
311321
auto bTemp2 = rewriter.create<mlir::arith::MulFOp>(loc, two, bTemp1);
312322
auto bTemp3 = rewriter.create<mlir::arith::SubFOp>(loc, bTemp2, one);
313-
auto beta = rewriter.create<mlir::math::AcosOp>(loc, bTemp3);
323+
auto clampedLow =
324+
rewriter.create<mlir::arith::MaximumFOp>(loc, bTemp3, negOne);
325+
auto clamped =
326+
rewriter.create<mlir::arith::MinimumFOp>(loc, clampedLow, one);
327+
auto beta = rewriter.create<mlir::math::AcosOp>(loc, clamped);
314328

315329
// intermediates to check for gimbal lock (|beta| and |beta - PI|)
316330
auto absBeta = rewriter.create<mlir::math::AbsFOp>(loc, beta);
@@ -343,31 +357,31 @@ struct MergeRotationGatesPattern final
343357
rewriter.create<mlir::arith::MulFOp>(loc, two, thetaMinus);
344358

345359
// Safe Case (no gimbal lock):
346-
// alphaSafe = theta+ - theta-
347-
// gammaSafe = theta+ + theta-
360+
// alphaSafe = theta+ + theta-
361+
// gammaSafe = theta+ - theta-
348362
auto alphaSafe =
349-
rewriter.create<mlir::arith::SubFOp>(loc, thetaPlus, thetaMinus);
350-
auto gammaSafe =
351363
rewriter.create<mlir::arith::AddFOp>(loc, thetaPlus, thetaMinus);
364+
auto gammaSafe =
365+
rewriter.create<mlir::arith::SubFOp>(loc, thetaPlus, thetaMinus);
352366

353367
// Unsafe Case (gimbal lock):
354368
// when b = 0 then alpha = 2 * (atan2(z,w))
355-
// when b = PI then alpha = 2 * (atan2(-z, y))
369+
// when b = PI then alpha = 2 * (atan2(-x, y))
356370
// gamma is set to zero in both cases
357371
auto alphaUnsafe = rewriter.create<mlir::arith::SelectOp>(
358372
loc, safe1, twoThetaMinus, twoThetaPlus);
359373

360374
// TODO: could add some normalization here for alpha and gamma otherwise
361375
// they can be outside of [-PI, PI].
362376

363-
// choose correct alpha and gamma weather safe or not
377+
// choose correct alpha and gamma whether safe or not
364378
auto alpha = rewriter.create<mlir::arith::SelectOp>(loc, safe, alphaSafe,
365379
alphaUnsafe);
366380
auto gamma =
367381
rewriter.create<mlir::arith::SelectOp>(loc, safe, gammaSafe, zero);
368382

369-
return rewriter.create<UOp>(loc, op.getInputQubit(0), alpha.getResult(),
370-
beta.getResult(), gamma.getResult());
383+
return rewriter.create<UOp>(loc, op.getInputQubit(0), beta.getResult(),
384+
alpha.getResult(), gamma.getResult());
371385
}
372386

373387
/**

mlir/unittests/Compiler/test_compiler_pipeline.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3736,4 +3736,44 @@ TEST_F(CompilerPipelineTest, Bell) {
37363736
});
37373737
}
37383738

3739+
// ##################################################
3740+
// # Rotation Merge Pass Test
3741+
// ##################################################
3742+
3743+
/**
3744+
* @brief Test: Rotation merging pass is invoked during optimization stage
3745+
*
3746+
* @details
3747+
* The merged U gate parameters are computed via floating-point arithmetic
3748+
* that is not bit-identical across platforms, so we cannot use
3749+
* verifyAllStages with hardcoded expected values. Instead, we compare
3750+
* the optimization output with and without the pass enabled.
3751+
* Correctness of the pass is tested in a dedicated test.
3752+
*/
3753+
TEST_F(CompilerPipelineTest, RotationGateMergingPass) {
3754+
::qc::QuantumComputation comp;
3755+
comp.addQubitRegister(1, "q");
3756+
comp.rz(1.0, 0);
3757+
comp.rx(1.0, 0);
3758+
3759+
// Run with merging enabled
3760+
config.mergeRotationGates = true;
3761+
3762+
auto module = importQuantumCircuit(comp);
3763+
ASSERT_TRUE(module);
3764+
ASSERT_TRUE(runPipeline(module.get()).succeeded());
3765+
const auto withMerging = record.afterOptimization;
3766+
3767+
// Run with merging disabled
3768+
config.mergeRotationGates = false;
3769+
record = {};
3770+
3771+
module = importQuantumCircuit(comp);
3772+
ASSERT_TRUE(module);
3773+
ASSERT_TRUE(runPipeline(module.get()).succeeded());
3774+
const auto withoutMerging = record.afterOptimization;
3775+
3776+
// The outputs must differ, proving the pass ran and transformed the IR
3777+
EXPECT_NE(withMerging, withoutMerging);
3778+
}
37393779
} // namespace

mlir/unittests/Dialect/QCO/Transforms/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,11 @@ add_executable(mqt-core-mlir-dialect-qco-transforms-test test_qco_quaternion_mer
1010

1111
target_link_libraries(
1212
mqt-core-mlir-dialect-qco-transforms-test
13-
# TODO figure out correct dependencies
1413
PRIVATE GTest::gtest_main
1514
MLIRQCOProgramBuilder
1615
MLIRQCOTransforms
1716
MLIRIR
18-
MLIRPass # for PassManager
17+
MLIRPass
1918
MLIRSupport
2019
LLVMSupport)
2120

0 commit comments

Comments
 (0)