Skip to content

Commit 4c9b173

Browse files
authored
Modify apply_unitary and unitary to take in numpy unitary matrix. (#7419)
Addressing issue: #7050 TESTED=Ran following tests for files affected. ``` check/pytest-changed-files check/pytest check/format-incremental --apply check/pylint-changed-files ```
1 parent 5ea1d04 commit 4c9b173

12 files changed

+46
-49
lines changed

cirq-core/cirq/protocols/apply_mixture_protocol.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import numpy as np
2323
from typing_extensions import Protocol
2424

25-
from cirq import linalg
2625
from cirq._doc import doc_private
2726
from cirq.protocols import qid_shape_protocol
2827
from cirq.protocols.apply_unitary_protocol import apply_unitary, ApplyUnitaryArgs
@@ -332,32 +331,6 @@ def _apply_unitary_strat(
332331
return right_result
333332

334333

335-
def _apply_unitary_from_matrix_strat(
336-
val: np.ndarray, args: ApplyMixtureArgs, is_density_matrix: bool
337-
) -> np.ndarray | None:
338-
"""Used to enact mixture tuples that are given as (probability, np.ndarray)
339-
340-
If `val` does not support `apply_unitary` returns None.
341-
"""
342-
qid_shape = tuple(args.target_tensor.shape[i] for i in args.left_axes)
343-
matrix_tensor = np.reshape(val.astype(args.target_tensor.dtype), qid_shape * 2)
344-
linalg.targeted_left_multiply(
345-
matrix_tensor, args.target_tensor, args.left_axes, out=args.auxiliary_buffer0
346-
)
347-
348-
if not is_density_matrix:
349-
return args.auxiliary_buffer0
350-
# No need to transpose as we are acting on the tensor
351-
# representation of matrix, so transpose is done for us.
352-
linalg.targeted_left_multiply(
353-
np.conjugate(matrix_tensor),
354-
args.auxiliary_buffer0,
355-
cast(tuple[int], args.right_axes),
356-
out=args.target_tensor,
357-
)
358-
return args.target_tensor
359-
360-
361334
def _apply_mixture_from_mixture_strat(
362335
val: Any, args: ApplyMixtureArgs, is_density_matrix: bool
363336
) -> np.ndarray | None:
@@ -373,8 +346,6 @@ def _apply_mixture_from_mixture_strat(
373346
for prob, op in prob_mix:
374347
np.copyto(dst=args.target_tensor, src=args.auxiliary_buffer1)
375348
right_result = _apply_unitary_strat(op, args, is_density_matrix)
376-
if right_result is None:
377-
right_result = _apply_unitary_from_matrix_strat(op, args, is_density_matrix)
378349

379350
args.out_buffer += prob * right_result
380351

cirq-core/cirq/protocols/apply_unitary_protocol.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -469,15 +469,20 @@ def _apply_unitary_from_matrix(matrix: np.ndarray, unitary_value: Any, args: App
469469
def _strat_apply_unitary_from_unitary(
470470
unitary_value: Any, args: ApplyUnitaryArgs
471471
) -> np.ndarray | None:
472-
# Check for magic method.
473-
method = getattr(unitary_value, '_unitary_', None)
474-
if method is None:
475-
return NotImplemented
476-
477-
# Attempt to get the unitary matrix.
478-
matrix = method()
479-
if matrix is NotImplemented or matrix is None:
480-
return matrix
472+
if isinstance(unitary_value, np.ndarray):
473+
matrix = unitary_value
474+
if not linalg.is_unitary(matrix):
475+
return None
476+
else:
477+
# Check for magic method.
478+
method = getattr(unitary_value, '_unitary_', None)
479+
if method is None:
480+
return NotImplemented
481+
482+
# Attempt to get the unitary matrix.
483+
matrix = method()
484+
if matrix is NotImplemented or matrix is None:
485+
return matrix
481486

482487
return _apply_unitary_from_matrix(matrix, unitary_value, args)
483488

cirq-core/cirq/protocols/apply_unitary_protocol_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,13 @@ def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> np.ndarray:
5656
args.target_tensor[one] *= -1
5757
return args.target_tensor
5858

59-
fails = [NoUnitaryEffect(), HasApplyReturnsNotImplemented()]
59+
fails = [NoUnitaryEffect(), HasApplyReturnsNotImplemented(), m * 2]
6060
passes = [
6161
HasUnitary(),
6262
HasApplyReturnsNotImplementedButHasUnitary(),
6363
HasApplyOutputInBuffer(),
6464
HasApplyMutateInline(),
65+
m,
6566
]
6667

6768
def make_input():

cirq-core/cirq/protocols/has_unitary_protocol.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import numpy as np
2020
from typing_extensions import Protocol
2121

22-
from cirq import qis
22+
from cirq import linalg, qis
2323
from cirq._doc import doc_private
2424
from cirq.protocols import qid_shape_protocol
2525
from cirq.protocols.apply_unitary_protocol import ApplyUnitaryArgs
@@ -112,6 +112,8 @@ def has_unitary(val: Any, *, allow_decompose: bool = True) -> bool:
112112

113113
def _strat_has_unitary_from_has_unitary(val: Any) -> bool | None:
114114
"""Attempts to infer a value's unitary-ness via its _has_unitary_ method."""
115+
if isinstance(val, np.ndarray):
116+
return linalg.is_unitary(val)
115117
if hasattr(val, '_has_unitary_'):
116118
result = val._has_unitary_()
117119
if result is NotImplemented:

cirq-core/cirq/protocols/has_unitary_protocol_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,13 @@ class Yes:
6161
def _unitary_(self):
6262
return np.array([[1]])
6363

64+
m = np.diag([1, -1])
6465
assert not cirq.has_unitary(No1())
6566
assert not cirq.has_unitary(No2())
67+
assert not cirq.has_unitary(m * 2)
6668
assert cirq.has_unitary(Yes())
6769
assert cirq.has_unitary(Yes(), allow_decompose=False)
70+
assert cirq.has_unitary(m)
6871

6972

7073
def test_via_apply_unitary() -> None:

cirq-core/cirq/protocols/kraus_protocol.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,7 @@ def kraus(
148148
mixture_getter = getattr(val, '_mixture_', None)
149149
mixture_result = NotImplemented if mixture_getter is None else mixture_getter()
150150
if mixture_result is not NotImplemented and mixture_result is not None:
151-
return tuple(
152-
np.sqrt(p) * (u if isinstance(u, np.ndarray) else unitary(u)) for p, u in mixture_result
153-
)
151+
return tuple(np.sqrt(p) * unitary(u) for p, u in mixture_result)
154152

155153
unitary_getter = getattr(val, '_unitary_', None)
156154
unitary_result = NotImplemented if unitary_getter is None else unitary_getter()

cirq-core/cirq/protocols/mixture_protocol.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def mixture(
9494
mixture_getter = getattr(val, '_mixture_', None)
9595
result = NotImplemented if mixture_getter is None else mixture_getter()
9696
if result is not NotImplemented and result is not None:
97-
return tuple((p, u if isinstance(u, np.ndarray) else unitary(u)) for p, u in result)
97+
return tuple((p, unitary(u)) for p, u in result)
9898

9999
unitary_getter = getattr(val, '_unitary_', None)
100100
result = NotImplemented if unitary_getter is None else unitary_getter()

cirq-core/cirq/protocols/mixture_protocol_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919

2020
import cirq
2121

22-
a = np.array([1])
23-
b = np.array([1j])
22+
a = np.array([[1]])
23+
b = np.array([[1j]])
2424

2525

2626
class NoMethod:

cirq-core/cirq/protocols/pauli_expansion_protocol_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def _unitary_(self) -> np.ndarray:
5454

5555

5656
@pytest.mark.parametrize(
57-
'val', (NoMethod(), ReturnsNotImplemented(), HasQuditUnitary(), 123, np.eye(2), object(), cirq)
57+
'val', (NoMethod(), ReturnsNotImplemented(), HasQuditUnitary(), 123, object(), cirq)
5858
)
5959
def test_raises_no_pauli_expansion(val) -> None:
6060
assert cirq.pauli_expansion(val, default=None) is None

cirq-core/cirq/protocols/unitary_protocol.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import numpy as np
2121
from typing_extensions import Protocol
2222

23+
from cirq import linalg
2324
from cirq._doc import doc_private
2425
from cirq.protocols import qid_shape_protocol
2526
from cirq.protocols.apply_unitary_protocol import apply_unitaries, ApplyUnitaryArgs
@@ -84,6 +85,7 @@ def unitary(
8485
8586
The matrix is determined by any one of the following techniques:
8687
88+
- If the value is a numpy array, it is returned directly.
8789
- The value has a `_unitary_` method that returns something besides None or
8890
NotImplemented. The matrix is whatever the method returned.
8991
- The value has a `_decompose_` method that returns a list of operations,
@@ -110,7 +112,13 @@ def unitary(
110112
Raises:
111113
TypeError: `val` doesn't have a unitary effect and no default value was
112114
specified.
115+
ValueError: `val` is a numpy array that is not unitary.
113116
"""
117+
if isinstance(val, np.ndarray):
118+
if not linalg.is_unitary(val):
119+
raise ValueError("The provided numpy array is not unitary.")
120+
return val
121+
114122
strats = [
115123
_strat_unitary_from_unitary,
116124
_strat_unitary_from_apply_unitary,

cirq-core/cirq/protocols/unitary_protocol_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,15 @@ def test_unitary():
161161
_ = cirq.unitary(ReturnsNotImplemented())
162162
assert cirq.unitary(ReturnsMatrix()) is m1
163163

164+
# Test that numpy arrays are handled directly
165+
test_matrix = np.array([[1, 0], [0, 1]])
166+
assert cirq.unitary(test_matrix, NotImplemented) is test_matrix
167+
168+
# Test that non-unitary numpy arrays raise ValueError
169+
non_unitary_matrix = np.array([[1, 1], [0, 1]])
170+
with pytest.raises(ValueError, match="The provided numpy array is not unitary"):
171+
_ = cirq.unitary(non_unitary_matrix)
172+
164173
assert cirq.unitary(NoMethod(), None) is None
165174
assert cirq.unitary(ReturnsNotImplemented(), None) is None
166175
assert cirq.unitary(ReturnsMatrix(), None) is m1

cirq-google/cirq_google/transformers/analytical_decompositions/two_qubit_to_sycamore_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,12 @@ def test_known_two_qubit_op_decomposition(op, theta_range):
8686
cirq.FSimGate(0.25, 0.85).on(*_QUBITS),
8787
cirq.XX(*_QUBITS),
8888
cirq.YY(*_QUBITS),
89-
*[cirq.testing.random_unitary(4, random_state=1234) for _ in range(10)],
89+
cirq.MatrixGate(cirq.testing.random_unitary(4)).on(*_QUBITS),
9090
],
9191
)
9292
def test_unknown_two_qubit_op_decomposition(op):
9393
assert cg.known_2q_op_to_sycamore_operations(op) is None
94-
if cirq.has_unitary(op) and cirq.num_qubits(op) == 2:
94+
if not cirq.is_parameterized(op) and cirq.num_qubits(op) == 2:
9595
matrix_2q_circuit = cirq.Circuit(
9696
cg.two_qubit_matrix_to_sycamore_operations(_QUBITS[0], _QUBITS[1], cirq.unitary(op))
9797
)

0 commit comments

Comments
 (0)