From 834ee27dfc99cc7394e19a0c7c5d8a60c4d4bf5b Mon Sep 17 00:00:00 2001 From: Nour Yosri Date: Mon, 22 Apr 2024 11:46:58 -0700 Subject: [PATCH 1/7] Add support for np.random.Generator --- cirq-core/cirq/__init__.py | 1 + cirq-core/cirq/value/__init__.py | 3 ++ cirq-core/cirq/value/prng.py | 80 ++++++++++++++++++++++++++++ cirq-core/cirq/value/prng_test.py | 48 +++++++++++++++++ cirq-core/cirq/value/random_state.py | 2 + 5 files changed, 134 insertions(+) create mode 100644 cirq-core/cirq/value/prng.py create mode 100644 cirq-core/cirq/value/prng_test.py diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 2dc2034600a..97ac368cafd 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -536,6 +536,7 @@ MeasurementKey, MeasurementType, PeriodicValue, + PRNG_OR_SEED_LIKE, RANDOM_STATE_OR_SEED_LIKE, state_vector_to_probabilities, SympyCondition, diff --git a/cirq-core/cirq/value/__init__.py b/cirq-core/cirq/value/__init__.py index a810bf108e9..81b39d5bcd0 100644 --- a/cirq-core/cirq/value/__init__.py +++ b/cirq-core/cirq/value/__init__.py @@ -65,3 +65,6 @@ from cirq.value.type_alias import TParamKey, TParamVal, TParamValComplex from cirq.value.value_equality_attr import value_equality + + +from cirq.value.prng import parse_prng, CustomPRNG, PRNG_OR_SEED_LIKE diff --git a/cirq-core/cirq/value/prng.py b/cirq-core/cirq/value/prng.py new file mode 100644 index 00000000000..237f8717f02 --- /dev/null +++ b/cirq-core/cirq/value/prng.py @@ -0,0 +1,80 @@ +# Copyright 2024 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import TypeVar, Union, overload + +import numpy as np + +from cirq._doc import document + + +class CustomPRNG(abc.ABC): ... + + +_CUSTOM_PRNG_T = TypeVar("_CUSTOM_PRNG_T", bound=CustomPRNG) +_PRNG_T = Union[np.random.Generator, np.random.RandomState, _CUSTOM_PRNG_T] +_SEED_T = Union[int, None] +PRNG_OR_SEED_LIKE = Union[None, int, np.random.RandomState, np.random.Generator, _CUSTOM_PRNG_T] + +document( + PRNG_OR_SEED_LIKE, + """A pseudorandom number generator or object that can be converted to one. + + If an integer or None, turns into a `np.random.Generator` seeded with that + value. + + If none of the above, it is used unmodified. In this case, it is assumed + that the object implements whatever methods are required for the use case + at hand. For example, it might be an existing instance of `np.random.Generator` + or `np.random.RandomState` or a custom pseudorandom number generator implementation + and in that case, it has to inherit `cirq.value.CustomPRNG`. + """, +) + + +@overload +def parse_prng(prng_or_seed: _SEED_T) -> np.random.Generator: ... + + +@overload +def parse_prng(prng_or_seed: np.random.Generator) -> np.random.Generator: ... + + +@overload +def parse_prng(prng_or_seed: np.random.RandomState) -> np.random.RandomState: ... + + +@overload +def parse_prng(prng_or_seed: _CUSTOM_PRNG_T) -> _CUSTOM_PRNG_T: ... + + +def parse_prng( + prng_or_seed: PRNG_OR_SEED_LIKE, +) -> Union[np.random.Generator, np.random.RandomState, _CUSTOM_PRNG_T]: + """Interpret an object as a pseudorandom number generator. + + If `prng_or_seed` is None or an integer, returns `np.random.default_rng(prng_or_seed)`. + Otherwise, returns `prng_or_seed` unmodified. + + Args: + prng_or_seed: The object to be used as or converted to a pseudorandom + number generator. + + Returns: + The pseudorandom number generator object. + """ + if prng_or_seed is None or isinstance(prng_or_seed, int): + return np.random.default_rng(prng_or_seed) + return prng_or_seed diff --git a/cirq-core/cirq/value/prng_test.py b/cirq-core/cirq/value/prng_test.py new file mode 100644 index 00000000000..741ef012752 --- /dev/null +++ b/cirq-core/cirq/value/prng_test.py @@ -0,0 +1,48 @@ +# Copyright 2024 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union +import numpy as np +import cirq + + +class TestPrng(cirq.value.CustomPRNG): + + def random(self, size): + return tuple(range(size)) + + +def _sample(prng): + return tuple(prng.random(10)) + + +def test_parse_rng() -> None: + eq = cirq.testing.EqualsTester() + + # An `np.random.Generator` or a seed. + group_inputs: List[Union[int, np.random.Generator]] = [42, np.random.default_rng(42)] + group: List[np.random.Generator] = [cirq.value.parse_prng(s) for s in group_inputs] + eq.add_equality_group(*[_sample(g) for g in group]) + + # A None seed. + prng: np.random.Generator = cirq.value.parse_prng(None) + eq.add_equality_group(_sample(prng)) + + # Custom PRNG. + custom_prng: TestPrng = cirq.value.parse_prng(TestPrng()) + eq.add_equality_group(_sample(custom_prng)) + + # RandomState PRNG. + random_state: np.random.RandomState = np.random.RandomState(42) + eq.add_equality_group(_sample(cirq.value.parse_prng(random_state))) diff --git a/cirq-core/cirq/value/random_state.py b/cirq-core/cirq/value/random_state.py index ab884bd4705..38522a9574a 100644 --- a/cirq-core/cirq/value/random_state.py +++ b/cirq-core/cirq/value/random_state.py @@ -33,6 +33,8 @@ at hand. For example, it might be an existing instance of `np.random.RandomState` or a custom pseudorandom number generator implementation. + + Note: prefer to use cirq.PRNG_OR_SEED_LIKE. """, ) From 9ce26d6ce5f9d5a39d8d9d816ae181b76d6c0eb2 Mon Sep 17 00:00:00 2001 From: Nour Yosri Date: Mon, 22 Apr 2024 12:00:20 -0700 Subject: [PATCH 2/7] serialization --- cirq-core/cirq/protocols/json_test_data/spec.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cirq-core/cirq/protocols/json_test_data/spec.py b/cirq-core/cirq/protocols/json_test_data/spec.py index 22ae86051e0..3a127812fc2 100644 --- a/cirq-core/cirq/protocols/json_test_data/spec.py +++ b/cirq-core/cirq/protocols/json_test_data/spec.py @@ -154,6 +154,7 @@ 'QUANTUM_STATE_LIKE', 'QubitOrderOrList', 'RANDOM_STATE_OR_SEED_LIKE', + 'PRNG_OR_SEED_LIKE', 'STATE_VECTOR_LIKE', 'Sweepable', 'TParamKey', From 30a73e7034846a420bb47b319ee8e1c2d8b3e601 Mon Sep 17 00:00:00 2001 From: Nour Yosri Date: Tue, 23 Apr 2024 13:47:59 -0700 Subject: [PATCH 3/7] nit --- cirq-core/cirq/value/prng_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/value/prng_test.py b/cirq-core/cirq/value/prng_test.py index 741ef012752..74afdc50150 100644 --- a/cirq-core/cirq/value/prng_test.py +++ b/cirq-core/cirq/value/prng_test.py @@ -44,5 +44,5 @@ def test_parse_rng() -> None: eq.add_equality_group(_sample(custom_prng)) # RandomState PRNG. - random_state: np.random.RandomState = np.random.RandomState(42) - eq.add_equality_group(_sample(cirq.value.parse_prng(random_state))) + random_state: np.random.RandomState = cirq.value.parse_prng(np.random.RandomState(42)) + eq.add_equality_group(_sample(random_state)) From c6cfb378853407cf919ba91eb2cdcbdd1a1fd694 Mon Sep 17 00:00:00 2001 From: Nour Yosri Date: Wed, 8 May 2024 11:43:47 -0700 Subject: [PATCH 4/7] address comments --- cirq-core/cirq/value/__init__.py | 2 +- cirq-core/cirq/value/prng.py | 65 +++++++++-------------- cirq-core/cirq/value/prng_test.py | 26 ++++----- cirq-core/cirq/value/random_state.py | 10 +++- cirq-core/cirq/value/random_state_test.py | 2 + 5 files changed, 50 insertions(+), 55 deletions(-) diff --git a/cirq-core/cirq/value/__init__.py b/cirq-core/cirq/value/__init__.py index 81b39d5bcd0..4847b1ff5a5 100644 --- a/cirq-core/cirq/value/__init__.py +++ b/cirq-core/cirq/value/__init__.py @@ -67,4 +67,4 @@ from cirq.value.value_equality_attr import value_equality -from cirq.value.prng import parse_prng, CustomPRNG, PRNG_OR_SEED_LIKE +from cirq.value.prng import parse_prng, PRNG_OR_SEED_LIKE diff --git a/cirq-core/cirq/value/prng.py b/cirq-core/cirq/value/prng.py index 237f8717f02..1dbd2d49f8d 100644 --- a/cirq-core/cirq/value/prng.py +++ b/cirq-core/cirq/value/prng.py @@ -12,61 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. -import abc -from typing import TypeVar, Union, overload +from typing import Union +import numbers import numpy as np from cirq._doc import document +from cirq.value.random_state import RANDOM_STATE_OR_SEED_LIKE - -class CustomPRNG(abc.ABC): ... - - -_CUSTOM_PRNG_T = TypeVar("_CUSTOM_PRNG_T", bound=CustomPRNG) -_PRNG_T = Union[np.random.Generator, np.random.RandomState, _CUSTOM_PRNG_T] _SEED_T = Union[int, None] -PRNG_OR_SEED_LIKE = Union[None, int, np.random.RandomState, np.random.Generator, _CUSTOM_PRNG_T] +PRNG_OR_SEED_LIKE = Union[None, int, np.random.RandomState, np.random.Generator] document( PRNG_OR_SEED_LIKE, """A pseudorandom number generator or object that can be converted to one. - If an integer or None, turns into a `np.random.Generator` seeded with that - value. - - If none of the above, it is used unmodified. In this case, it is assumed - that the object implements whatever methods are required for the use case - at hand. For example, it might be an existing instance of `np.random.Generator` - or `np.random.RandomState` or a custom pseudorandom number generator implementation - and in that case, it has to inherit `cirq.value.CustomPRNG`. + If is an integer or None, turns into a `np.random.Generator` seeded with that value. + If is an instance of `np.random.Generator` or a subclass of it, return as is. + If is an instance of `np.random.RandomState` or has a `randint` method, returns + `np.random.default_rng(rs.randint(2**63 - 1))` """, ) -@overload -def parse_prng(prng_or_seed: _SEED_T) -> np.random.Generator: ... - - -@overload -def parse_prng(prng_or_seed: np.random.Generator) -> np.random.Generator: ... - - -@overload -def parse_prng(prng_or_seed: np.random.RandomState) -> np.random.RandomState: ... - - -@overload -def parse_prng(prng_or_seed: _CUSTOM_PRNG_T) -> _CUSTOM_PRNG_T: ... - - def parse_prng( - prng_or_seed: PRNG_OR_SEED_LIKE, -) -> Union[np.random.Generator, np.random.RandomState, _CUSTOM_PRNG_T]: + prng_or_seed: Union[PRNG_OR_SEED_LIKE, RANDOM_STATE_OR_SEED_LIKE] +) -> np.random.Generator: """Interpret an object as a pseudorandom number generator. + If `prng_or_seed` is an `np.random.Generator`, return it unmodified. If `prng_or_seed` is None or an integer, returns `np.random.default_rng(prng_or_seed)`. - Otherwise, returns `prng_or_seed` unmodified. + If `prng_or_seed` is an instance of `np.random.RandomState` or has a `randint` method, + returns `np.random.default_rng(prng_or_seed.randint(2**63 - 1))`. Args: prng_or_seed: The object to be used as or converted to a pseudorandom @@ -74,7 +51,17 @@ def parse_prng( Returns: The pseudorandom number generator object. + + Raises: + TypeError: If `prng_or_seed` is can't be converted to an np.random.Generator. """ - if prng_or_seed is None or isinstance(prng_or_seed, int): - return np.random.default_rng(prng_or_seed) - return prng_or_seed + if isinstance(prng_or_seed, np.random.Generator): + return prng_or_seed + if prng_or_seed is None or isinstance(prng_or_seed, numbers.Integral): + return np.random.default_rng(prng_or_seed if prng_or_seed is None else int(prng_or_seed)) + if isinstance(prng_or_seed, np.random.RandomState): + return np.random.default_rng(prng_or_seed.randint(2**63 - 1)) + randint = getattr(prng_or_seed, "randint", None) + if randint is not None: + return np.random.default_rng(randint(2**63 - 1)) + raise TypeError(f"{prng_or_seed} can't be converted to a pseudorandom number generator") diff --git a/cirq-core/cirq/value/prng_test.py b/cirq-core/cirq/value/prng_test.py index 74afdc50150..9ca84ae46bf 100644 --- a/cirq-core/cirq/value/prng_test.py +++ b/cirq-core/cirq/value/prng_test.py @@ -13,14 +13,11 @@ # limitations under the License. from typing import List, Union -import numpy as np -import cirq - -class TestPrng(cirq.value.CustomPRNG): +import pytest +import numpy as np - def random(self, size): - return tuple(range(size)) +import cirq def _sample(prng): @@ -36,13 +33,16 @@ def test_parse_rng() -> None: eq.add_equality_group(*[_sample(g) for g in group]) # A None seed. - prng: np.random.Generator = cirq.value.parse_prng(None) + prng = cirq.value.parse_prng(None) eq.add_equality_group(_sample(prng)) - # Custom PRNG. - custom_prng: TestPrng = cirq.value.parse_prng(TestPrng()) - eq.add_equality_group(_sample(custom_prng)) - # RandomState PRNG. - random_state: np.random.RandomState = cirq.value.parse_prng(np.random.RandomState(42)) - eq.add_equality_group(_sample(random_state)) + prng = cirq.value.parse_prng(np.random.RandomState(42)) + eq.add_equality_group(_sample(prng)) + + # np.random module + prng = cirq.value.parse_prng(np.random) + eq.add_equality_group(_sample(prng)) + + with pytest.raises(TypeError): + _ = cirq.value.parse_prng(1.0) diff --git a/cirq-core/cirq/value/random_state.py b/cirq-core/cirq/value/random_state.py index 38522a9574a..771b80634f0 100644 --- a/cirq-core/cirq/value/random_state.py +++ b/cirq-core/cirq/value/random_state.py @@ -28,6 +28,9 @@ If an integer, turns into a `np.random.RandomState` seeded with that integer. + If `random_state` is an instance of `np.random.Generator`, returns a + `np.random.RandomState` seeded with `random_state.bit_generator`. + If none of the above, it is used unmodified. In this case, it is assumed that the object implements whatever methods are required for the use case at hand. For example, it might be an existing instance of @@ -43,8 +46,9 @@ def parse_random_state(random_state: RANDOM_STATE_OR_SEED_LIKE) -> np.random.Ran """Interpret an object as a pseudorandom number generator. If `random_state` is None, returns the module `np.random`. - If `random_state` is an integer, returns - `np.random.RandomState(random_state)`. + If `random_state` is an integer, returns `np.random.RandomState(random_state)`. + If `random_state` is an instance of `np.random.Generator`, returns a + `np.random.RandomState` seeded with `random_state.bit_generator`. Otherwise, returns `random_state` unmodified. Args: @@ -58,5 +62,7 @@ def parse_random_state(random_state: RANDOM_STATE_OR_SEED_LIKE) -> np.random.Ran return cast(np.random.RandomState, np.random) elif isinstance(random_state, int): return np.random.RandomState(random_state) + elif isinstance(random_state, np.random.Generator): + return np.random.RandomState(random_state.bit_generator) else: return cast(np.random.RandomState, random_state) diff --git a/cirq-core/cirq/value/random_state_test.py b/cirq-core/cirq/value/random_state_test.py index 694965ea0a9..3f74ea0932e 100644 --- a/cirq-core/cirq/value/random_state_test.py +++ b/cirq-core/cirq/value/random_state_test.py @@ -42,3 +42,5 @@ def rand(prng): vals = [prng.rand() for prng in prngs] eq = cirq.testing.EqualsTester() eq.add_equality_group(*vals) + + eq.add_equality_group(cirq.value.parse_random_state(np.random.default_rng(0)).rand()) From d5ade4f4365a571b18b2d9ebfbe01d2e490ab2e8 Mon Sep 17 00:00:00 2001 From: Nour Yosri Date: Wed, 8 May 2024 11:47:22 -0700 Subject: [PATCH 5/7] nit --- cirq-core/cirq/value/prng.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cirq-core/cirq/value/prng.py b/cirq-core/cirq/value/prng.py index 1dbd2d49f8d..8bdfa731a7b 100644 --- a/cirq-core/cirq/value/prng.py +++ b/cirq-core/cirq/value/prng.py @@ -20,7 +20,6 @@ from cirq._doc import document from cirq.value.random_state import RANDOM_STATE_OR_SEED_LIKE -_SEED_T = Union[int, None] PRNG_OR_SEED_LIKE = Union[None, int, np.random.RandomState, np.random.Generator] document( From 1e9c7ab362ebdf7135ca93571f5a6c7428840ed3 Mon Sep 17 00:00:00 2001 From: Nour Yosri Date: Wed, 8 May 2024 12:02:12 -0700 Subject: [PATCH 6/7] nit --- cirq-core/cirq/value/prng.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cirq-core/cirq/value/prng.py b/cirq-core/cirq/value/prng.py index 8bdfa731a7b..ea7492413bd 100644 --- a/cirq-core/cirq/value/prng.py +++ b/cirq-core/cirq/value/prng.py @@ -29,7 +29,7 @@ If is an integer or None, turns into a `np.random.Generator` seeded with that value. If is an instance of `np.random.Generator` or a subclass of it, return as is. If is an instance of `np.random.RandomState` or has a `randint` method, returns - `np.random.default_rng(rs.randint(2**63 - 1))` + `np.random.default_rng(rs.randint(2**31 - 1))` """, ) @@ -42,7 +42,7 @@ def parse_prng( If `prng_or_seed` is an `np.random.Generator`, return it unmodified. If `prng_or_seed` is None or an integer, returns `np.random.default_rng(prng_or_seed)`. If `prng_or_seed` is an instance of `np.random.RandomState` or has a `randint` method, - returns `np.random.default_rng(prng_or_seed.randint(2**63 - 1))`. + returns `np.random.default_rng(prng_or_seed.randint(2**31 - 1))`. Args: prng_or_seed: The object to be used as or converted to a pseudorandom @@ -59,8 +59,8 @@ def parse_prng( if prng_or_seed is None or isinstance(prng_or_seed, numbers.Integral): return np.random.default_rng(prng_or_seed if prng_or_seed is None else int(prng_or_seed)) if isinstance(prng_or_seed, np.random.RandomState): - return np.random.default_rng(prng_or_seed.randint(2**63 - 1)) + return np.random.default_rng(prng_or_seed.randint(2**31 - 1)) randint = getattr(prng_or_seed, "randint", None) if randint is not None: - return np.random.default_rng(randint(2**63 - 1)) + return np.random.default_rng(randint(2**31 - 1)) raise TypeError(f"{prng_or_seed} can't be converted to a pseudorandom number generator") From aff9556bb437a7855819b5ca0993d5cf69b51f56 Mon Sep 17 00:00:00 2001 From: Nour Yosri Date: Wed, 8 May 2024 12:02:47 -0700 Subject: [PATCH 7/7] nit --- cirq-core/cirq/value/prng.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cirq-core/cirq/value/prng.py b/cirq-core/cirq/value/prng.py index ea7492413bd..86440ccee3d 100644 --- a/cirq-core/cirq/value/prng.py +++ b/cirq-core/cirq/value/prng.py @@ -29,7 +29,7 @@ If is an integer or None, turns into a `np.random.Generator` seeded with that value. If is an instance of `np.random.Generator` or a subclass of it, return as is. If is an instance of `np.random.RandomState` or has a `randint` method, returns - `np.random.default_rng(rs.randint(2**31 - 1))` + `np.random.default_rng(rs.randint(2**31))` """, ) @@ -42,7 +42,7 @@ def parse_prng( If `prng_or_seed` is an `np.random.Generator`, return it unmodified. If `prng_or_seed` is None or an integer, returns `np.random.default_rng(prng_or_seed)`. If `prng_or_seed` is an instance of `np.random.RandomState` or has a `randint` method, - returns `np.random.default_rng(prng_or_seed.randint(2**31 - 1))`. + returns `np.random.default_rng(prng_or_seed.randint(2**31))`. Args: prng_or_seed: The object to be used as or converted to a pseudorandom @@ -59,8 +59,8 @@ def parse_prng( if prng_or_seed is None or isinstance(prng_or_seed, numbers.Integral): return np.random.default_rng(prng_or_seed if prng_or_seed is None else int(prng_or_seed)) if isinstance(prng_or_seed, np.random.RandomState): - return np.random.default_rng(prng_or_seed.randint(2**31 - 1)) + return np.random.default_rng(prng_or_seed.randint(2**31)) randint = getattr(prng_or_seed, "randint", None) if randint is not None: - return np.random.default_rng(randint(2**31 - 1)) + return np.random.default_rng(randint(2**31)) raise TypeError(f"{prng_or_seed} can't be converted to a pseudorandom number generator")