Skip to content

Commit 5ea1d04

Browse files
authored
Handle non-2D arrays in cirq.unitary (#7427)
Return False rather than raising IndexError for vector and scalar arrays.
1 parent a839837 commit 5ea1d04

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

cirq-core/cirq/linalg/predicates.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,12 @@ def is_unitary(matrix: np.ndarray, *, rtol: float = 1e-5, atol: float = 1e-8) ->
115115
Returns:
116116
Whether the matrix is unitary within the given tolerance.
117117
"""
118-
return matrix.shape[0] == matrix.shape[1] and np.allclose(
119-
matrix.dot(np.conj(matrix.T)), np.eye(matrix.shape[0]), rtol=rtol, atol=atol
118+
return (
119+
matrix.ndim == 2
120+
and matrix.shape[0] == matrix.shape[1]
121+
and np.allclose(
122+
matrix.dot(np.conj(matrix.T)), np.eye(matrix.shape[0]), rtol=rtol, atol=atol
123+
)
120124
)
121125

122126

cirq-core/cirq/linalg/predicates_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,13 @@ def test_is_hermitian_tolerance():
103103

104104

105105
def test_is_unitary():
106+
assert not cirq.is_unitary(np.empty((0,)))
106107
assert cirq.is_unitary(np.empty((0, 0)))
107108
assert not cirq.is_unitary(np.empty((1, 0)))
108109
assert not cirq.is_unitary(np.empty((0, 1)))
110+
assert not cirq.is_unitary(np.empty((0, 0, 0)))
109111

112+
assert not cirq.is_unitary(np.array(1))
110113
assert cirq.is_unitary(np.array([[1]]))
111114
assert cirq.is_unitary(np.array([[-1]]))
112115
assert cirq.is_unitary(np.array([[1j]]))

0 commit comments

Comments
 (0)