1111#include " mlir/Dialect/QCO/IR/QCODialect.h"
1212#include " mlir/Dialect/QCO/Transforms/Passes.h"
1313
14- #include < algorithm>
15- #include < array>
1614#include < cassert>
1715#include < cmath>
1816#include < cstdint>
17+ #include < llvm/ADT/TypeSwitch.h>
1918#include < llvm/Support/ErrorHandling.h>
2019#include < mlir/Dialect/Arith/IR/Arith.h>
2120#include < mlir/Dialect/Math/IR/Math.h>
2625#include < mlir/Support/LLVM.h>
2726#include < mlir/Transforms/GreedyPatternRewriteDriver.h>
2827#include < numbers>
29- #include < string_view >
28+ #include < optional >
3029#include < utility>
3130
3231namespace mlir ::qco {
@@ -52,17 +51,14 @@ struct MergeRotationGatesPattern final
5251
5352 enum class RotationAxis : std::uint8_t { X, Y, Z };
5453
55- static constexpr std::array<std::string_view, 4 > MERGEABLE_GATES = {
56- " u" , " rx" , " ry" , " rz" };
57-
5854 /* *
5955 * @brief Checks if an operation is a mergeable rotation gate (rx, ry, rz, u).
6056 *
61- * @param name Name of the operation to check
57+ * @param op The operation to check
6258 * @return True if mergeable, false otherwise
6359 */
64- static bool isMergeable (std::string_view name ) {
65- return std::ranges::find (MERGEABLE_GATES, name) != MERGEABLE_GATES. end ( );
60+ static bool isMergeable (mlir::Operation* op ) {
61+ return mlir::isa<RXOp, RYOp, RZOp, UOp>(op );
6662 }
6763
6864 /* *
@@ -78,13 +74,28 @@ struct MergeRotationGatesPattern final
7874 */
7975 [[nodiscard]] static bool areQuaternionMergeable (mlir::Operation& a,
8076 mlir::Operation& b) {
81- const auto aName = a.getName ().stripDialect ().str ();
82- const auto bName = b.getName ().stripDialect ().str ();
83-
84- if (!(isMergeable (aName) && isMergeable (bName))) {
77+ if (!isMergeable (&a) || !isMergeable (&b)) {
8578 return false ;
8679 }
87- return (aName != bName) || (aName == " u" && bName == " u" );
80+
81+ // Different gate types OR both are U gates
82+ return (a.getName () != b.getName ()) ||
83+ (mlir::isa<UOp>(a) && mlir::isa<UOp>(b));
84+ }
85+
86+ /* *
87+ * @brief Returns the rotation axis for a single-axis rotation gate.
88+ *
89+ * @param op The operation to query
90+ * @return The rotation axis, or std::nullopt if the operation is not a
91+ * single-axis rotation gate (RX, RY, RZ)
92+ */
93+ static std::optional<RotationAxis> getRotationAxis (mlir::Operation* op) {
94+ return llvm::TypeSwitch<mlir::Operation*, std::optional<RotationAxis>>(op)
95+ .Case <RXOp>([](auto ) { return RotationAxis::X; })
96+ .Case <RYOp>([](auto ) { return RotationAxis::Y; })
97+ .Case <RZOp>([](auto ) { return RotationAxis::Z; })
98+ .Default ([](auto ) { return std::nullopt ; });
8899 }
89100
90101 /* *
@@ -130,6 +141,8 @@ struct MergeRotationGatesPattern final
130141 case RotationAxis::Z:
131142 return {.w = cos, .x = zero, .y = zero, .z = sin};
132143 } // NOLINT(bugprone-branch-clone): false positive, branches differ
144+
145+ llvm_unreachable (" Invalid rotation axis" );
133146 }
134147
135148 /* *
@@ -142,24 +155,15 @@ struct MergeRotationGatesPattern final
142155 */
143156 static Quaternion quaternionFromRotation (UnitaryOpInterface op,
144157 mlir::PatternRewriter& rewriter) {
145- auto const type = op->getName ().stripDialect ().str ();
146-
147- if (type == " u" ) {
158+ if (mlir::isa<UOp>(op)) {
148159 return quaternionFromUGate (op, rewriter);
149160 }
150161
151- auto loc = op->getLoc ();
152- auto angle = op.getParameter (0 );
153-
154- if (type == " rx" ) {
155- return createAxisQuaternion (angle, RotationAxis::X, loc, rewriter);
156- }
157- if (type == " ry" ) {
158- return createAxisQuaternion (angle, RotationAxis::Y, loc, rewriter);
159- }
160- if (type == " rz" ) {
161- return createAxisQuaternion (angle, RotationAxis::Z, loc, rewriter);
162+ if (auto axis = getRotationAxis (op.getOperation ())) {
163+ return createAxisQuaternion (op.getParameter (0 ), *axis, op->getLoc (),
164+ rewriter);
162165 }
166+
163167 llvm_unreachable (" Unsupported operation type" );
164168 }
165169
@@ -227,14 +231,14 @@ struct MergeRotationGatesPattern final
227231 /* *
228232 * @brief Converts a u-gate to quaternion representation.
229233 *
230- * U(alpha, beta, gamma ) uses ZYZ decomposition: RZ(alpha ) -> RY(beta ) ->
231- * RZ(gamma ).
234+ * U(theta, phi, lambda ) uses ZYZ decomposition: RZ(lambda ) -> RY(theta ) ->
235+ * RZ(phi ).
232236 *
233237 * When composing rotations, quaternion multiplication follows matrix
234238 * multiplication order (right-to-left), which is the reverse of the
235239 * application sequence:
236- * Sequential application: RZ(alpha ), then RY(beta ), then RZ(gamma )
237- * Quaternion product: Qgamma * Qbeta * Qalpha
240+ * Sequential application: RZ(lambda ), then RY(theta ), then RZ(phi )
241+ * Quaternion product: qPhi * qTheta * qLambda
238242 *
239243 * @param op The u-gate operation to convert
240244 * @param rewriter Pattern rewriter for creating new operations
@@ -245,26 +249,27 @@ struct MergeRotationGatesPattern final
245249 auto loc = op->getLoc ();
246250
247251 // U gate uses ZYZ decomposition:
248- // U(alpha, beta, gamma) = Rz(alpha) -> Ry(beta) -> Rz(gamma)
249- auto qAlpha = createAxisQuaternion (op.getParameter (0 ), RotationAxis::Z, loc,
250- rewriter);
251- auto qBeta = createAxisQuaternion (op.getParameter (1 ), RotationAxis::Y, loc,
252- rewriter);
253- auto qGamma = createAxisQuaternion (op.getParameter (2 ), RotationAxis::Z, loc,
252+ // U(theta, phi, lambda) uses ZYZ decomposition: RZ(lambda) -> RY(theta) ->
253+ // RZ(phi)
254+ auto qTheta = createAxisQuaternion (op.getParameter (0 ), RotationAxis::Y, loc,
254255 rewriter);
255-
256- // qGamma * qBeta * qAlpha (multiplication in reverse order!)
257- auto temp = hamiltonProduct (qGamma, qBeta, op, rewriter);
258- return hamiltonProduct (temp, qAlpha, op, rewriter);
256+ auto qPhi = createAxisQuaternion (op.getParameter (1 ), RotationAxis::Z, loc,
257+ rewriter);
258+ auto qLambda = createAxisQuaternion (op.getParameter (2 ), RotationAxis::Z,
259+ loc, rewriter);
260+
261+ // qPhi * qTheta * qLambda (multiplication in reverse order!)
262+ auto temp = hamiltonProduct (qPhi, qTheta, op, rewriter);
263+ return hamiltonProduct (temp, qLambda, op, rewriter);
259264 }
260265
261266 /* *
262267 * @brief Converts a quaternion to a u-gate using ZYZ Euler angle extraction.
263268 *
264269 * For unit quaternion q = w + x*i + y*j + z*k, extracts u-gate parameters:
265- * alpha = atan2(z, w) - atan2(-x, y)
270+ * alpha = atan2(z, w) + atan2(-x, y)
266271 * beta = acos(2 * (w^2 + z^2) - 1)
267- * gamma = atan2(z, w) + atan2(-x, y)
272+ * gamma = atan2(z, w) - atan2(-x, y)
268273 *
269274 * Based on Bernardes & Viollet (2022), simplified for unit quaternions and
270275 * proper ZYZ Euler angles (Chapter 3.3):
@@ -287,6 +292,9 @@ struct MergeRotationGatesPattern final
287292 auto loc = op->getLoc ();
288293
289294 auto floatType = op.getParameter (0 ).getType ();
295+ // constant -1.0
296+ auto negOneAttr = rewriter.getFloatAttr (floatType, -1.0 );
297+ auto negOne = rewriter.create <mlir::arith::ConstantOp>(loc, negOneAttr);
290298 // constant 0.0
291299 auto zeroAttr = rewriter.getFloatAttr (floatType, 0.0 );
292300 auto zero = rewriter.create <mlir::arith::ConstantOp>(loc, zeroAttr);
@@ -305,12 +313,18 @@ struct MergeRotationGatesPattern final
305313
306314 // calculate angle beta (for y-rotation)
307315 // beta = acos(2 * (w^2 + z^2) - 1)
316+ // NOTE: the term (2 * (w^2 + z^2) - 1) is clamped to [-1, 1],
317+ // otherwise acos could produce NaN.
308318 auto ww = rewriter.create <mlir::arith::MulFOp>(loc, q.w , q.w );
309319 auto zz = rewriter.create <mlir::arith::MulFOp>(loc, q.z , q.z );
310320 auto bTemp1 = rewriter.create <mlir::arith::AddFOp>(loc, ww, zz);
311321 auto bTemp2 = rewriter.create <mlir::arith::MulFOp>(loc, two, bTemp1);
312322 auto bTemp3 = rewriter.create <mlir::arith::SubFOp>(loc, bTemp2, one);
313- auto beta = rewriter.create <mlir::math::AcosOp>(loc, bTemp3);
323+ auto clampedLow =
324+ rewriter.create <mlir::arith::MaximumFOp>(loc, bTemp3, negOne);
325+ auto clamped =
326+ rewriter.create <mlir::arith::MinimumFOp>(loc, clampedLow, one);
327+ auto beta = rewriter.create <mlir::math::AcosOp>(loc, clamped);
314328
315329 // intermediates to check for gimbal lock (|beta| and |beta - PI|)
316330 auto absBeta = rewriter.create <mlir::math::AbsFOp>(loc, beta);
@@ -343,31 +357,31 @@ struct MergeRotationGatesPattern final
343357 rewriter.create <mlir::arith::MulFOp>(loc, two, thetaMinus);
344358
345359 // Safe Case (no gimbal lock):
346- // alphaSafe = theta+ - theta-
347- // gammaSafe = theta+ + theta-
360+ // alphaSafe = theta+ + theta-
361+ // gammaSafe = theta+ - theta-
348362 auto alphaSafe =
349- rewriter.create <mlir::arith::SubFOp>(loc, thetaPlus, thetaMinus);
350- auto gammaSafe =
351363 rewriter.create <mlir::arith::AddFOp>(loc, thetaPlus, thetaMinus);
364+ auto gammaSafe =
365+ rewriter.create <mlir::arith::SubFOp>(loc, thetaPlus, thetaMinus);
352366
353367 // Unsafe Case (gimbal lock):
354368 // when b = 0 then alpha = 2 * (atan2(z,w))
355- // when b = PI then alpha = 2 * (atan2(-z , y))
369+ // when b = PI then alpha = 2 * (atan2(-x , y))
356370 // gamma is set to zero in both cases
357371 auto alphaUnsafe = rewriter.create <mlir::arith::SelectOp>(
358372 loc, safe1, twoThetaMinus, twoThetaPlus);
359373
360374 // TODO: could add some normalization here for alpha and gamma otherwise
361375 // they can be outside of [-PI, PI].
362376
363- // choose correct alpha and gamma weather safe or not
377+ // choose correct alpha and gamma whether safe or not
364378 auto alpha = rewriter.create <mlir::arith::SelectOp>(loc, safe, alphaSafe,
365379 alphaUnsafe);
366380 auto gamma =
367381 rewriter.create <mlir::arith::SelectOp>(loc, safe, gammaSafe, zero);
368382
369- return rewriter.create <UOp>(loc, op.getInputQubit (0 ), alpha .getResult (),
370- beta .getResult (), gamma.getResult ());
383+ return rewriter.create <UOp>(loc, op.getInputQubit (0 ), beta .getResult (),
384+ alpha .getResult (), gamma.getResult ());
371385 }
372386
373387 /* *
0 commit comments