diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 6e6a9f28f67..72d88ecc7f3 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -529,6 +529,46 @@ from cirq.value import ( + prngs + ABCMetaImplementAnyOneOf, + alternative, + big_endian_bits_to_int, + big_endian_digits_to_int, + big_endian_int_to_bits, + big_endian_int_to_digits, + canonicalize_half_turns, + chosen_angle_to_canonical_half_turns, + chosen_angle_to_half_turns, + ClassicalDataDictionaryStore, + ClassicalDataStore, + ClassicalDataStoreReader, + Condition, + Duration, + DURATION_LIKE, + KeyCondition, + LinearDict, + MEASUREMENT_KEY_SEPARATOR, + MeasurementKey, + MeasurementType, + PeriodicValue, + PRNG_OR_SEED_LIKE, + RANDOM_STATE_OR_SEED_LIKE, + state_vector_to_probabilities, + SympyCondition, + Timestamp, + TParamKey, + TParamVal, + TParamValComplex, + validate_probability, + value_equality, + KET_PLUS, + KET_MINUS, + KET_IMAG, + KET_MINUS_IMAG, + KET_ZERO, + KET_ONE, + PAULI_STATES, + ProductState, ABCMetaImplementAnyOneOf as ABCMetaImplementAnyOneOf, alternative as alternative, big_endian_bits_to_int as big_endian_bits_to_int, diff --git a/cirq-core/cirq/protocols/json_test_data/spec.py b/cirq-core/cirq/protocols/json_test_data/spec.py index 031825ea943..770e163f125 100644 --- a/cirq-core/cirq/protocols/json_test_data/spec.py +++ b/cirq-core/cirq/protocols/json_test_data/spec.py @@ -156,6 +156,7 @@ 'QUANTUM_STATE_LIKE', 'QubitOrderOrList', 'RANDOM_STATE_OR_SEED_LIKE', + 'PRNG_OR_SEED_LIKE', 'STATE_VECTOR_LIKE', 'Sweepable', 'TParamKey', diff --git a/cirq-core/cirq/value/__init__.py b/cirq-core/cirq/value/__init__.py index 6c385654e44..8211c97a01b 100644 --- a/cirq-core/cirq/value/__init__.py +++ b/cirq-core/cirq/value/__init__.py @@ -85,4 +85,9 @@ TParamValComplex as TParamValComplex, ) +from cirq.value.value_equality_attr import value_equality + + +from cirq.value.prng import parse_prng, PRNG_OR_SEED_LIKE + from cirq.value.value_equality_attr import value_equality as value_equality diff --git a/cirq-core/cirq/value/prng.py b/cirq-core/cirq/value/prng.py new file mode 100644 index 00000000000..ade9f4cb6b0 --- /dev/null +++ b/cirq-core/cirq/value/prng.py @@ -0,0 +1,78 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numbers +import numpy as np +from typing import Union + +from cirq._doc import document + + +# Type for PRNG or seed-like input. +PRNG_OR_SEED_LIKE = Union[None, int, np.random.RandomState, np.random.Generator] +document( + PRNG_OR_SEED_LIKE, + """A pseudorandom number generator or object that can be converted to one. + + Can be an instance of `np.random.Generator`, an integer seed, a `np.random.RandomState`, or None. + """, +) + + +def parse_prng(prng_or_seed: PRNG_OR_SEED_LIKE) -> np.random.Generator: + """Converts the input object into a `numpy.random.Generator`. + + - If `prng_or_seed` is already a `np.random.Generator`, it's returned directly. + - If `prng_or_seed` is `None`, returns a new `np.random.Generator` + instance (seeded unpredictably by NumPy). + - If `prng_or_seed` is an integer, returns `np.random.default_rng(prng_or_seed)`. + - If `prng_or_seed` is an instance of `np.random.RandomState`, returns a `np.random.Generator` initialized with the RandomState's bit generator or falls back on a random seed. + - Passing the `np.random` module itself is explicitly disallowed. + + Args: + prng_or_seed: The object to be used as or converted to a Generator. + + Returns: + The `numpy.random.Generator` object. + + Raises: + TypeError: If `prng_or_seed` is the `np.random` module or cannot be + converted to a `np.random.Generator`. + """ + if prng_or_seed is np.random: + raise TypeError( + "Passing the 'np.random' module is not supported. " + "Use None to get a default np.random.Generator instance." + ) + + if isinstance(prng_or_seed, np.random.Generator): + return prng_or_seed + + if prng_or_seed is None: + return np.random.default_rng() + + if isinstance(prng_or_seed, numbers.Integral): + return np.random.default_rng(int(prng_or_seed)) + + if isinstance(prng_or_seed, np.random.RandomState): + bit_gen = getattr(prng_or_seed, '_bit_generator', None) + if bit_gen is not None: + return np.random.default_rng(bit_gen) + seed_val = prng_or_seed.randint(2**31) + return np.random.default_rng(seed_val) + + raise TypeError( + f"Input {prng_or_seed} (type: {type(prng_or_seed).__name__}) cannot be converted " + f"to a {np.random.Generator.__name__}" + ) diff --git a/cirq-core/cirq/value/prng_test.py b/cirq-core/cirq/value/prng_test.py new file mode 100644 index 00000000000..33d99a1d3b4 --- /dev/null +++ b/cirq-core/cirq/value/prng_test.py @@ -0,0 +1,85 @@ +# Copyright 2024 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import numpy as np + +import cirq + + +def test_parse_prng_generator_passthrough(): + """Test that passing an existing Generator returns the same object.""" + rng = np.random.default_rng(12345) + assert cirq.value.parse_prng(rng) is rng + + +def test_parse_prng_none(): + """Test that passing None returns a new Generator instance.""" + rng1 = cirq.value.parse_prng(None) + rng2 = cirq.value.parse_prng(None) + assert rng1 is not rng2 + assert type(rng1) is np.random.Generator + assert type(rng2) is np.random.Generator + + +def test_parse_prng_int_seeding(): + """Test that integer seeds create predictable Generators.""" + rng_int = cirq.value.parse_prng(42) + rng_npint = cirq.value.parse_prng(np.int64(42)) + assert rng_int.random() == rng_npint.random() + + rng_different_seed = cirq.value.parse_prng(43) + rng_int = cirq.value.parse_prng(42) + assert rng_int.random() != rng_different_seed.random() + + +def test_parse_prng_module_disallowed(): + """Test that passing the np.random module raises TypeError.""" + with pytest.raises(TypeError, match="not supported"): + cirq.value.parse_prng(np.random) + + +def test_parse_prng_invalid_types(): + """Test that unsupported types raise TypeError.""" + + match = "cannot be converted" + with pytest.raises(TypeError, match=match): + cirq.value.parse_prng(1.0) + + with pytest.raises(TypeError, match=match): + cirq.value.parse_prng("not a seed") + + with pytest.raises(TypeError, match=match): + cirq.value.parse_prng([1, 2, 3]) + + with pytest.raises(TypeError, match=match): + cirq.value.parse_prng(object()) + + +def test_parse_prng_equality_tester_on_output(): + """Use EqualsTester to verify output consistency for valid inputs.""" + eq = cirq.testing.EqualsTester() + + eq.add_equality_group( + cirq.value.parse_prng(42).random(), + cirq.value.parse_prng(np.int32(42)).random(), + cirq.value.parse_prng(np.random.default_rng(42)).random(), + ) + + eq.add_equality_group( + cirq.value.parse_prng(np.random.RandomState(50)).random(), + cirq.value.parse_prng(np.random.RandomState(50)).random(), + ) + + eq.add_equality_group(cirq.value.parse_prng(None).random()) diff --git a/cirq-core/cirq/value/random_state.py b/cirq-core/cirq/value/random_state.py index fe60ef1db94..9f60a7edf68 100644 --- a/cirq-core/cirq/value/random_state.py +++ b/cirq-core/cirq/value/random_state.py @@ -30,11 +30,16 @@ If an integer, turns into a `np.random.RandomState` seeded with that integer. + If `random_state` is an instance of `np.random.Generator`, returns a + `np.random.RandomState` seeded with `random_state.bit_generator`. + If none of the above, it is used unmodified. In this case, it is assumed that the object implements whatever methods are required for the use case at hand. For example, it might be an existing instance of `np.random.RandomState` or a custom pseudorandom number generator implementation. + + Note: prefer to use cirq.PRNG_OR_SEED_LIKE. """, ) @@ -43,8 +48,9 @@ def parse_random_state(random_state: RANDOM_STATE_OR_SEED_LIKE) -> np.random.Ran """Interpret an object as a pseudorandom number generator. If `random_state` is None, returns the module `np.random`. - If `random_state` is an integer, returns - `np.random.RandomState(random_state)`. + If `random_state` is an integer, returns `np.random.RandomState(random_state)`. + If `random_state` is an instance of `np.random.Generator`, returns a + `np.random.RandomState` seeded with `random_state.bit_generator`. Otherwise, returns `random_state` unmodified. Args: @@ -58,5 +64,7 @@ def parse_random_state(random_state: RANDOM_STATE_OR_SEED_LIKE) -> np.random.Ran return cast(np.random.RandomState, np.random) elif isinstance(random_state, int): return np.random.RandomState(random_state) + elif isinstance(random_state, np.random.Generator): + return np.random.RandomState(random_state.bit_generator) else: return cast(np.random.RandomState, random_state) diff --git a/cirq-core/cirq/value/random_state_test.py b/cirq-core/cirq/value/random_state_test.py index fd2f6745d23..a3e0611a0af 100644 --- a/cirq-core/cirq/value/random_state_test.py +++ b/cirq-core/cirq/value/random_state_test.py @@ -44,3 +44,5 @@ def rand(prng): vals = [prng.rand() for prng in prngs1] eq = cirq.testing.EqualsTester() eq.add_equality_group(*vals) + + eq.add_equality_group(cirq.value.parse_random_state(np.random.default_rng(0)).rand())