-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Add support for np.random.Generator #6566
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
834ee27
9ce26d6
30a73e7
f455207
c6cfb37
0fc07df
d5ade4f
1e9c7ab
aff9556
b06f456
770e8fe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
# Copyright 2024 The Cirq Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import abc | ||
from typing import TypeVar, Union, overload | ||
|
||
import numpy as np | ||
|
||
from cirq._doc import document | ||
|
||
|
||
class CustomPRNG(abc.ABC): ... | ||
|
||
|
||
_CUSTOM_PRNG_T = TypeVar("_CUSTOM_PRNG_T", bound=CustomPRNG) | ||
_PRNG_T = Union[np.random.Generator, np.random.RandomState, _CUSTOM_PRNG_T] | ||
_SEED_T = Union[int, None] | ||
PRNG_OR_SEED_LIKE = Union[None, int, np.random.RandomState, np.random.Generator, _CUSTOM_PRNG_T] | ||
|
||
document( | ||
PRNG_OR_SEED_LIKE, | ||
"""A pseudorandom number generator or object that can be converted to one. | ||
|
||
If an integer or None, turns into a `np.random.Generator` seeded with that | ||
value. | ||
|
||
If none of the above, it is used unmodified. In this case, it is assumed | ||
that the object implements whatever methods are required for the use case | ||
at hand. For example, it might be an existing instance of `np.random.Generator` | ||
or `np.random.RandomState` or a custom pseudorandom number generator implementation | ||
and in that case, it has to inherit `cirq.value.CustomPRNG`. | ||
""", | ||
) | ||
|
||
|
||
@overload | ||
def parse_prng(prng_or_seed: _SEED_T) -> np.random.Generator: ... | ||
|
||
|
||
@overload | ||
def parse_prng(prng_or_seed: np.random.Generator) -> np.random.Generator: ... | ||
|
||
|
||
@overload | ||
def parse_prng(prng_or_seed: np.random.RandomState) -> np.random.RandomState: ... | ||
|
||
|
||
@overload | ||
def parse_prng(prng_or_seed: _CUSTOM_PRNG_T) -> _CUSTOM_PRNG_T: ... | ||
|
||
|
||
def parse_prng( | ||
prng_or_seed: PRNG_OR_SEED_LIKE, | ||
) -> Union[np.random.Generator, np.random.RandomState, _CUSTOM_PRNG_T]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The return type of different types tends to be a code smell. Such returned value is less useful for type checking. In addition, Generator and RandomState (not to mention _CUSTOM_PRNG_T) have different APIs so the I would propose an alternative approach: (1) convert the (2) extend parse_random_state to accept a Generator object and convert it to RandomState. (3) add method With these steps in place, we can keep all the existing interfaces that take This would also avoid bifurcation between RANDOM_STATE_OR_SEED_LIKE and PRNG_OR_SEED_LIKE types that may need several major releases to clear up. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ptal |
||
"""Interpret an object as a pseudorandom number generator. | ||
|
||
If `prng_or_seed` is None or an integer, returns `np.random.default_rng(prng_or_seed)`. | ||
Otherwise, returns `prng_or_seed` unmodified. | ||
|
||
Args: | ||
prng_or_seed: The object to be used as or converted to a pseudorandom | ||
number generator. | ||
|
||
Returns: | ||
The pseudorandom number generator object. | ||
""" | ||
if prng_or_seed is None or isinstance(prng_or_seed, int): | ||
NoureldinYosri marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return np.random.default_rng(prng_or_seed) | ||
return prng_or_seed |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,48 @@ | ||||||||||||||||||||||||||||
# Copyright 2024 The Cirq Developers | ||||||||||||||||||||||||||||
# | ||||||||||||||||||||||||||||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||||||||||||
# you may not use this file except in compliance with the License. | ||||||||||||||||||||||||||||
# You may obtain a copy of the License at | ||||||||||||||||||||||||||||
# | ||||||||||||||||||||||||||||
# https://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||||||||||||
# | ||||||||||||||||||||||||||||
# Unless required by applicable law or agreed to in writing, software | ||||||||||||||||||||||||||||
# distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||||||||||||||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||||||||||||
# See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||
# limitations under the License. | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
from typing import List, Union | ||||||||||||||||||||||||||||
import numpy as np | ||||||||||||||||||||||||||||
import cirq | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
class TestPrng(cirq.value.CustomPRNG): | ||||||||||||||||||||||||||||
NoureldinYosri marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
def random(self, size): | ||||||||||||||||||||||||||||
return tuple(range(size)) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
def _sample(prng): | ||||||||||||||||||||||||||||
return tuple(prng.random(10)) | ||||||||||||||||||||||||||||
Comment on lines
+23
to
+24
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we need this. One output from |
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
def test_parse_rng() -> None: | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||
eq = cirq.testing.EqualsTester() | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# An `np.random.Generator` or a seed. | ||||||||||||||||||||||||||||
group_inputs: List[Union[int, np.random.Generator]] = [42, np.random.default_rng(42)] | ||||||||||||||||||||||||||||
group: List[np.random.Generator] = [cirq.value.parse_prng(s) for s in group_inputs] | ||||||||||||||||||||||||||||
eq.add_equality_group(*[_sample(g) for g in group]) | ||||||||||||||||||||||||||||
Comment on lines
+30
to
+33
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let us not check cross-group inequality. Following the
Suggested change
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# A None seed. | ||||||||||||||||||||||||||||
prng: np.random.Generator = cirq.value.parse_prng(None) | ||||||||||||||||||||||||||||
eq.add_equality_group(_sample(prng)) | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a noop check for a single value. Perhaps replace with
if you are OK with the previous suggestion to have a singleton generator for None. |
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# Custom PRNG. | ||||||||||||||||||||||||||||
custom_prng: TestPrng = cirq.value.parse_prng(TestPrng()) | ||||||||||||||||||||||||||||
eq.add_equality_group(_sample(custom_prng)) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# RandomState PRNG. | ||||||||||||||||||||||||||||
random_state: np.random.RandomState = cirq.value.parse_prng(np.random.RandomState(42)) | ||||||||||||||||||||||||||||
eq.add_equality_group(_sample(random_state)) |
Uh oh!
There was an error while loading. Please reload this page.