diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 0a35b2b5687..c70c328a26c 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -538,6 +538,7 @@ MeasurementKey, MeasurementType, PeriodicValue, + PRNG_OR_SEED_LIKE, RANDOM_STATE_OR_SEED_LIKE, state_vector_to_probabilities, SympyCondition, diff --git a/cirq-core/cirq/protocols/json_test_data/spec.py b/cirq-core/cirq/protocols/json_test_data/spec.py index 22ae86051e0..3a127812fc2 100644 --- a/cirq-core/cirq/protocols/json_test_data/spec.py +++ b/cirq-core/cirq/protocols/json_test_data/spec.py @@ -154,6 +154,7 @@ 'QUANTUM_STATE_LIKE', 'QubitOrderOrList', 'RANDOM_STATE_OR_SEED_LIKE', + 'PRNG_OR_SEED_LIKE', 'STATE_VECTOR_LIKE', 'Sweepable', 'TParamKey', diff --git a/cirq-core/cirq/value/__init__.py b/cirq-core/cirq/value/__init__.py index a810bf108e9..4847b1ff5a5 100644 --- a/cirq-core/cirq/value/__init__.py +++ b/cirq-core/cirq/value/__init__.py @@ -65,3 +65,6 @@ from cirq.value.type_alias import TParamKey, TParamVal, TParamValComplex from cirq.value.value_equality_attr import value_equality + + +from cirq.value.prng import parse_prng, PRNG_OR_SEED_LIKE diff --git a/cirq-core/cirq/value/prng.py b/cirq-core/cirq/value/prng.py new file mode 100644 index 00000000000..86440ccee3d --- /dev/null +++ b/cirq-core/cirq/value/prng.py @@ -0,0 +1,66 @@ +# Copyright 2024 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union + +import numbers +import numpy as np + +from cirq._doc import document +from cirq.value.random_state import RANDOM_STATE_OR_SEED_LIKE + +PRNG_OR_SEED_LIKE = Union[None, int, np.random.RandomState, np.random.Generator] + +document( + PRNG_OR_SEED_LIKE, + """A pseudorandom number generator or object that can be converted to one. + + If is an integer or None, turns into a `np.random.Generator` seeded with that value. + If is an instance of `np.random.Generator` or a subclass of it, return as is. + If is an instance of `np.random.RandomState` or has a `randint` method, returns + `np.random.default_rng(rs.randint(2**31))` + """, +) + + +def parse_prng( + prng_or_seed: Union[PRNG_OR_SEED_LIKE, RANDOM_STATE_OR_SEED_LIKE] +) -> np.random.Generator: + """Interpret an object as a pseudorandom number generator. + + If `prng_or_seed` is an `np.random.Generator`, return it unmodified. + If `prng_or_seed` is None or an integer, returns `np.random.default_rng(prng_or_seed)`. + If `prng_or_seed` is an instance of `np.random.RandomState` or has a `randint` method, + returns `np.random.default_rng(prng_or_seed.randint(2**31))`. + + Args: + prng_or_seed: The object to be used as or converted to a pseudorandom + number generator. + + Returns: + The pseudorandom number generator object. + + Raises: + TypeError: If `prng_or_seed` is can't be converted to an np.random.Generator. + """ + if isinstance(prng_or_seed, np.random.Generator): + return prng_or_seed + if prng_or_seed is None or isinstance(prng_or_seed, numbers.Integral): + return np.random.default_rng(prng_or_seed if prng_or_seed is None else int(prng_or_seed)) + if isinstance(prng_or_seed, np.random.RandomState): + return np.random.default_rng(prng_or_seed.randint(2**31)) + randint = getattr(prng_or_seed, "randint", None) + if randint is not None: + return np.random.default_rng(randint(2**31)) + raise TypeError(f"{prng_or_seed} can't be converted to a pseudorandom number generator") diff --git a/cirq-core/cirq/value/prng_test.py b/cirq-core/cirq/value/prng_test.py new file mode 100644 index 00000000000..9ca84ae46bf --- /dev/null +++ b/cirq-core/cirq/value/prng_test.py @@ -0,0 +1,48 @@ +# Copyright 2024 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union + +import pytest +import numpy as np + +import cirq + + +def _sample(prng): + return tuple(prng.random(10)) + + +def test_parse_rng() -> None: + eq = cirq.testing.EqualsTester() + + # An `np.random.Generator` or a seed. + group_inputs: List[Union[int, np.random.Generator]] = [42, np.random.default_rng(42)] + group: List[np.random.Generator] = [cirq.value.parse_prng(s) for s in group_inputs] + eq.add_equality_group(*[_sample(g) for g in group]) + + # A None seed. + prng = cirq.value.parse_prng(None) + eq.add_equality_group(_sample(prng)) + + # RandomState PRNG. + prng = cirq.value.parse_prng(np.random.RandomState(42)) + eq.add_equality_group(_sample(prng)) + + # np.random module + prng = cirq.value.parse_prng(np.random) + eq.add_equality_group(_sample(prng)) + + with pytest.raises(TypeError): + _ = cirq.value.parse_prng(1.0) diff --git a/cirq-core/cirq/value/random_state.py b/cirq-core/cirq/value/random_state.py index ab884bd4705..771b80634f0 100644 --- a/cirq-core/cirq/value/random_state.py +++ b/cirq-core/cirq/value/random_state.py @@ -28,11 +28,16 @@ If an integer, turns into a `np.random.RandomState` seeded with that integer. + If `random_state` is an instance of `np.random.Generator`, returns a + `np.random.RandomState` seeded with `random_state.bit_generator`. + If none of the above, it is used unmodified. In this case, it is assumed that the object implements whatever methods are required for the use case at hand. For example, it might be an existing instance of `np.random.RandomState` or a custom pseudorandom number generator implementation. + + Note: prefer to use cirq.PRNG_OR_SEED_LIKE. """, ) @@ -41,8 +46,9 @@ def parse_random_state(random_state: RANDOM_STATE_OR_SEED_LIKE) -> np.random.Ran """Interpret an object as a pseudorandom number generator. If `random_state` is None, returns the module `np.random`. - If `random_state` is an integer, returns - `np.random.RandomState(random_state)`. + If `random_state` is an integer, returns `np.random.RandomState(random_state)`. + If `random_state` is an instance of `np.random.Generator`, returns a + `np.random.RandomState` seeded with `random_state.bit_generator`. Otherwise, returns `random_state` unmodified. Args: @@ -56,5 +62,7 @@ def parse_random_state(random_state: RANDOM_STATE_OR_SEED_LIKE) -> np.random.Ran return cast(np.random.RandomState, np.random) elif isinstance(random_state, int): return np.random.RandomState(random_state) + elif isinstance(random_state, np.random.Generator): + return np.random.RandomState(random_state.bit_generator) else: return cast(np.random.RandomState, random_state) diff --git a/cirq-core/cirq/value/random_state_test.py b/cirq-core/cirq/value/random_state_test.py index 694965ea0a9..3f74ea0932e 100644 --- a/cirq-core/cirq/value/random_state_test.py +++ b/cirq-core/cirq/value/random_state_test.py @@ -42,3 +42,5 @@ def rand(prng): vals = [prng.rand() for prng in prngs] eq = cirq.testing.EqualsTester() eq.add_equality_group(*vals) + + eq.add_equality_group(cirq.value.parse_random_state(np.random.default_rng(0)).rand())