Skip to content

Refactor Controlled Gate Tests for Non-Standard Control Values #7321

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 178 additions & 0 deletions cirq-core/cirq/ops/common_gates_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -17,6 +18,8 @@
import sympy

import cirq
import cirq.protocols
from cirq.ops import ControlledGate
from cirq.protocols.act_on_protocol_test import ExampleSimulationState

H = np.array([[1, 1], [1, -1]]) * np.sqrt(0.5)
Expand Down Expand Up @@ -227,6 +230,181 @@ def test_global_phase_controlled_gate(gate, matrix):
np.testing.assert_equal(cirq.unitary(gate.controlled()), matrix)


# --- Tests for non-standard control values --- START ---

def test_controlled_x_zero_control_type():
"""Tests that X.controlled(cv=[0]) returns a ControlledGate."""
gate = cirq.X.controlled(control_values=[0])
assert isinstance(gate, ControlledGate)
assert not isinstance(gate, cirq.CXPowGate) # Explicitly check it's not the specialized type
assert gate.sub_gate == cirq.X
assert gate.control_values == cirq.SumOfProducts([[0]])

def test_controlled_x_zero_control_unitary():
"""Tests the unitary of X.controlled(cv=[0])."""
gate = cirq.X.controlled(control_values=[0])
# Expected: Apply X if control is 0, else Identity
# |00> -> |01>
# |01> -> |00>
# |10> -> |10>
# |11> -> |11>
expected_matrix = np.array([
[0, 1, 0, 0],
[1, 0, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1]
], dtype=np.complex128)
actual_matrix = cirq.unitary(gate)
assert np.allclose(actual_matrix, expected_matrix)


def test_controlled_z_zero_control_type():
"""Tests that Z.controlled(cv=[0]) returns a ControlledGate."""
gate = cirq.Z.controlled(control_values=[0])
assert isinstance(gate, ControlledGate)
assert not isinstance(gate, cirq.CZPowGate)
assert gate.sub_gate == cirq.Z
assert gate.control_values == cirq.SumOfProducts([[0]])

def test_controlled_z_zero_control_unitary():
"""Tests the unitary of Z.controlled(cv=[0])."""
gate = cirq.Z.controlled(control_values=[0])
# Expected: Apply Z if control is 0, else Identity
# |00> -> |00>
# |01> -> -|01>
# |10> -> |10>
# |11> -> |11>
expected_matrix = np.array([
[1, 0, 0, 0],
[0, -1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1]
], dtype=np.complex128)
actual_matrix = cirq.unitary(gate)
assert np.allclose(actual_matrix, expected_matrix)


def test_controlled_cx_zero_control_type():
"""Tests that CX.controlled(cv=[0]) returns a ControlledGate(CX)."""
gate = cirq.CX.controlled(control_values=[0])
assert isinstance(gate, ControlledGate)
assert not isinstance(gate, cirq.CCXPowGate)
assert gate.sub_gate == cirq.CX
# Check control values, expect [(0,)] for the outer control
assert gate.control_values == cirq.SumOfProducts([[0]])

def test_controlled_cx_zero_control_unitary():
"""Tests the unitary of CX.controlled(cv=[0])."""
gate = cirq.CX.controlled(control_values=[0])
# Expected: Apply CX to q1, q2 if q0 is 0, else Identity on q1, q2
# Basis: |q0 q1 q2>
# |000> -> |000>
# |001> -> |001>
# |010> -> |011>
# |011> -> |010>
# |100> -> |100>
# |101> -> |101>
# |110> -> |110>
# |111> -> |111>
expected_matrix = np.array([
[1, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 1]
], dtype=np.complex128)
actual_matrix = cirq.unitary(gate)
assert np.allclose(actual_matrix, expected_matrix)


def test_controlled_cz_zero_control_type():
"""Tests that CZ.controlled(cv=[0]) returns a ControlledGate(CZ)."""
gate = cirq.CZ.controlled(control_values=[0])
assert isinstance(gate, ControlledGate)
assert not isinstance(gate, cirq.CCZPowGate)
assert gate.sub_gate == cirq.CZ
# Check control values, expect [(0,)] for the outer control
assert gate.control_values == cirq.SumOfProducts([[0]])

def test_controlled_cz_zero_control_unitary():
"""Tests the unitary of CZ.controlled(cv=[0])."""
gate = cirq.CZ.controlled(control_values=[0])
# Expected: Apply CZ to q1, q2 if q0 is 0, else Identity on q1, q2
# Basis: |q0 q1 q2>
# |000> -> |000>
# |001> -> |001>
# |010> -> |010>
# |011> -> -|011>
# |100> -> |100>
# |101> -> |101>
# |110> -> |110>
# |111> -> |111>
expected_matrix = np.array([
[1, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0, 0],
[0, 0, 0,-1, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 1]
], dtype=np.complex128)
actual_matrix = cirq.unitary(gate)
assert np.allclose(actual_matrix, expected_matrix)


def test_controlled_x_mixed_controls_type():
"""Tests that X.controlled(cv=[1, 0]) returns a ControlledGate."""
gate = cirq.X.controlled(control_values=[1, 0])
assert isinstance(gate, ControlledGate)
assert not isinstance(gate, cirq.CCXPowGate) # No specialization expected
assert gate.sub_gate == cirq.X
assert gate.control_values == cirq.SumOfProducts([[1, 0]])

def test_controlled_x_mixed_controls_unitary():
"""Tests the unitary of X.controlled(cv=[1, 0])."""
gate = cirq.X.controlled(control_values=[1, 0])
# Expected: Apply X to q2 if q0=1 and q1=0
# Basis: |q0 q1 q2>
# |100> -> |101>
# |101> -> |100>
# Other states unchanged
expected_matrix = np.identity(8, dtype=np.complex128)
expected_matrix[4, 4] = 0
expected_matrix[5, 5] = 0
expected_matrix[4, 5] = 1
expected_matrix[5, 4] = 1
actual_matrix = cirq.unitary(gate)
assert np.allclose(actual_matrix, expected_matrix)


def test_controlled_z_mixed_controls_type():
"""Tests that Z.controlled(cv=[1, 0]) returns a ControlledGate."""
gate = cirq.Z.controlled(control_values=[1, 0])
assert isinstance(gate, ControlledGate)
assert not isinstance(gate, cirq.CCZPowGate) # No specialization expected
assert gate.sub_gate == cirq.Z
assert gate.control_values == cirq.SumOfProducts([[1, 0]])

def test_controlled_z_mixed_controls_unitary():
"""Tests the unitary of Z.controlled(cv=[1, 0])."""
gate = cirq.Z.controlled(control_values=[1, 0])
# Expected: Apply Z to q2 if q0=1 and q1=0
# Basis: |q0 q1 q2>
# |100> -> |100>
# |101> -> -|101>
# Other states unchanged
expected_matrix = np.identity(8, dtype=np.complex128)
expected_matrix[5, 5] = -1
actual_matrix = cirq.unitary(gate)
assert np.allclose(actual_matrix, expected_matrix)


# --- Tests for non-standard control values --- END ---

def test_rot_gates_eq():
eq = cirq.testing.EqualsTester()
gates = [
Expand Down