Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
RELEASE_TYPE: minor

This release changes :class:`hypothesis.stateful.Bundle` to use the internals of
:func:`~hypothesis.strategies.sampled_from`, improving the ``filter`` and ``map`` methods.
In addition to performance improvements, you can now write ``consumes(some_bundle).filter(...)``!

Thanks to Reagan Lee for this feature (:issue:`3944`).
175 changes: 132 additions & 43 deletions hypothesis-python/src/hypothesis/stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from hypothesis.errors import InvalidArgument, InvalidDefinition
from hypothesis.internal.compat import add_note, batched
from hypothesis.internal.conjecture import utils as cu
from hypothesis.internal.conjecture.data import ConjectureData
from hypothesis.internal.conjecture.engine import BUFFER_SIZE
from hypothesis.internal.conjecture.junkdrawer import gc_cumulative_time
from hypothesis.internal.healthcheck import fail_health_check
Expand All @@ -54,8 +55,11 @@
from hypothesis.strategies._internal.strategies import (
Ex,
OneOfStrategy,
SampledFromStrategy,
SampledFromTransformationsT,
SearchStrategy,
check_strategy,
filter_not_satisfied,
)
from hypothesis.vendor.pretty import RepresentationPrinter

Expand Down Expand Up @@ -191,11 +195,11 @@ def output(s):
data = dict(data)
for k, v in list(data.items()):
if isinstance(v, VarReference):
data[k] = machine.names_to_values[v.name]
data[k] = v.value
elif isinstance(v, list) and all(
isinstance(item, VarReference) for item in v
):
data[k] = [machine.names_to_values[item.name] for item in v]
data[k] = [item.value for item in v]

label = f"execute:rule:{rule.function.__name__}"
start = perf_counter()
Expand Down Expand Up @@ -310,7 +314,7 @@ def __init__(self) -> None:
# copy since we pop from this as we run initialize rules.
self._initialize_rules_to_run = setup_state.initializers.copy()

self.bundles: dict[str, list] = {}
self.bundles: dict[str, list[str]] = {}
self.names_counters: collections.Counter = collections.Counter()
self.names_list: list[str] = []
self.names_to_values: dict[str, Any] = {}
Expand Down Expand Up @@ -439,7 +443,7 @@ def printer(obj, p, cycle, name=name):
if not _is_singleton(result):
self.__printer.singleton_pprinters.setdefault(id(result), printer)
self.names_to_values[name] = result
self.bundles.setdefault(target, []).append(VarReference(name))
self.bundles.setdefault(target, []).append(name)

def check_invariants(self, settings, output, runtimes):
for invar in self.invariants:
Expand Down Expand Up @@ -509,8 +513,12 @@ def __post_init__(self):
assert not isinstance(v, BundleReferenceStrategy)
if isinstance(v, Bundle):
bundles.append(v)
consume = isinstance(v, BundleConsumer)
v = BundleReferenceStrategy(v.name, consume=consume)
v = BundleReferenceStrategy(
v.name,
consume=v.consume,
force_repr=v.force_repr,
transformations=v._transformations,
)
self.arguments_strategies[k] = v
self.bundles = tuple(bundles)

Expand Down Expand Up @@ -544,25 +552,60 @@ def __hash__(self):
self_strategy = st.runner()


class BundleReferenceStrategy(SearchStrategy):
def __init__(self, name: str, *, consume: bool = False):
super().__init__()
class BundleReferenceStrategy(SampledFromStrategy[Ex]):

def __init__(
self,
name: str,
*,
consume: bool = False,
force_repr: Optional[str] = None,
transformations: SampledFromTransformationsT = (),
):
super().__init__(
[...],
force_repr=force_repr,
transformations=transformations,
) # Some random items that'll get replaced in do_draw
self.name = name
self.consume = consume

def do_draw(self, data):
machine = data.draw(self_strategy)
bundle = machine.bundle(self.name)
if not bundle:
def get_transformed_value(self, name: str) -> Ex:
value = self.machine.names_to_values[name]
return self._transform(value)

def get_element(self, i: int) -> int:
idx = self.elements[i]
name = self.bundle[idx]
value = self.get_transformed_value(name)
if value is filter_not_satisfied:
return filter_not_satisfied
return idx

def do_draw(self, data: ConjectureData) -> Ex:
self.machine = data.draw(self_strategy)
self.bundle = self.machine.bundle(self.name)
if not self.bundle:
data.mark_invalid(f"Cannot draw from empty bundle {self.name!r}")

# self.elements holds indices into self.bundle rather than the names themselves,
# so that the drawn position can be used to safely pop from the bundle.

# Shrink towards the right rather than the left. This makes it easier
# to delete data generated earlier, as when the error is towards the
# end there can be a lot of hard to remove padding.
position = data.draw_integer(0, len(bundle) - 1, shrink_towards=len(bundle))
self.elements = range(len(self.bundle))[::-1]

position = super().do_draw(data)
name = self.bundle[position]
if self.consume:
return bundle.pop(position) # pragma: no cover # coverage is flaky here
else:
return bundle[position]
self.bundle.pop(position) # pragma: no cover # coverage is flaky here

value = self.get_transformed_value(name)

# We need both the reference and the value itself to pretty-print deterministically
# and to preserve any bundle-specific transformations.
return VarReference(name, value)


class Bundle(SearchStrategy[Ex]):
Expand All @@ -585,26 +628,80 @@ class MyStateMachine(RuleBasedStateMachine):

If the ``consume`` argument is set to True, then all values that are
drawn from this bundle will be consumed (as above) when requested.

Bundles can be combined with |.map| and |.filter|:

.. code-block:: python

class Machine(RuleBasedStateMachine):
values = Bundle("values")

@initialize(target=values)
def populate_values(self):
return multiple(1, 2)

@rule(n=values.map(lambda x: -x))
def use_map(self, n):
pass

@rule(n=values.filter(lambda x: x > 1))
def use_filter(self, n):
pass
"""

def __init__(
self, name: str, *, consume: bool = False, draw_references: bool = True
self,
name: str,
*,
consume: bool = False,
force_repr: Optional[str] = None,
transformations: SampledFromTransformationsT = (),
) -> None:
super().__init__()
self.name = name
self.__reference_strategy = BundleReferenceStrategy(name, consume=consume)
self.draw_references = draw_references
self.__reference_strategy = BundleReferenceStrategy(
name,
consume=consume,
force_repr=force_repr,
transformations=transformations,
)

def do_draw(self, data):
machine = data.draw(self_strategy)
reference = data.draw(self.__reference_strategy)
return machine.names_to_values[reference.name]
@property
def consume(self):
return self.__reference_strategy.consume

@property
def force_repr(self):
return self.__reference_strategy.force_repr

@property
def _transformations(self):
return self.__reference_strategy._transformations

def do_draw(self, data: ConjectureData) -> Ex:
self.machine = data.draw(self_strategy)
var_reference = data.draw(self.__reference_strategy)
assert isinstance(var_reference, VarReference)
return var_reference.value

def __with_transform(self, method, fn):
return Bundle(
self.name,
consume=self.consume,
force_repr=self.force_repr,
transformations=(*self._transformations, (method, fn)),
)

def filter(self, condition):
return self.__with_transform("filter", condition)

def map(self, pack):
return self.__with_transform("map", pack)

def __repr__(self):
consume = self.__reference_strategy.consume
if consume is False:
if self.consume is False:
return f"Bundle(name={self.name!r})"
return f"Bundle(name={self.name!r}, {consume=})"
return f"Bundle(name={self.name!r}, {self.consume=})"

def calc_is_empty(self, recur):
# We assume that a bundle will grow over time
Expand All @@ -617,15 +714,6 @@ def _available(self, data):
machine = data.draw(self_strategy)
return bool(machine.bundle(self.name))

def flatmap(self, expand):
if self.draw_references:
return type(self)(
self.name,
consume=self.__reference_strategy.consume,
draw_references=False,
).flatmap(expand)
return super().flatmap(expand)

def __hash__(self):
# Making this hashable means we hit the fast path of "everything is
# hashable" in st.sampled_from label calculation when sampling which rule
Expand All @@ -635,11 +723,6 @@ def __hash__(self):
return hash(("Bundle", self.name))


class BundleConsumer(Bundle[Ex]):
def __init__(self, bundle: Bundle[Ex]) -> None:
super().__init__(bundle.name, consume=True)


def consumes(bundle: Bundle[Ex]) -> SearchStrategy[Ex]:
"""When introducing a rule in a RuleBasedStateMachine, this function can
be used to mark bundles from which each value used in a step with the
Expand All @@ -655,7 +738,12 @@ def consumes(bundle: Bundle[Ex]) -> SearchStrategy[Ex]:
"""
if not isinstance(bundle, Bundle):
raise TypeError("Argument to be consumed must be a bundle.")
return BundleConsumer(bundle)
return Bundle(
bundle.name,
consume=True,
transformations=bundle._transformations,
force_repr=bundle.force_repr,
)


@dataclass
Expand Down Expand Up @@ -700,7 +788,7 @@ def _convert_targets(targets, target):
)
raise InvalidArgument(msg % (t, type(t)))
while isinstance(t, Bundle):
if isinstance(t, BundleConsumer):
if t.consume:
note_deprecation(
f"Using consumes({t.name}) doesn't makes sense in this context. "
"This will be an error in a future version of Hypothesis.",
Expand Down Expand Up @@ -944,6 +1032,7 @@ def rule_wrapper(*args, **kwargs):
@dataclass
class VarReference:
name: str
value: Any


# There are multiple alternatives for annotating the `precond` type, all of them
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,12 @@ def is_hashable(value: object) -> bool:
return _is_hashable(value)[0]


# An ordered chain of ("filter" | "map", callable) steps, as accepted by the
# ``transformations`` argument of SampledFromStrategy.
SampledFromTransformationsT: "TypeAlias" = tuple[
tuple[Literal["filter", "map"], Callable[[Ex], Any]],
...,
]


class SampledFromStrategy(SearchStrategy[Ex]):
"""A strategy which samples from a set of elements. This is essentially
equivalent to using a OneOfStrategy over Just strategies but may be more
Expand All @@ -578,10 +584,7 @@ def __init__(
*,
force_repr: Optional[str] = None,
force_repr_braces: Optional[tuple[str, str]] = None,
transformations: tuple[
tuple[Literal["filter", "map"], Callable[[Ex], Any]],
...,
] = (),
transformations: SampledFromTransformationsT = (),
):
super().__init__()
self.elements = cu.check_sample(elements, "sampled_from")
Expand Down Expand Up @@ -698,7 +701,7 @@ def _transform(
# conservative than necessary
element: Ex, # type: ignore
) -> Union[Ex, UniqueIdentifier]:
# Used in UniqueSampledListStrategy
# Used in UniqueSampledListStrategy and BundleReferenceStrategy
for name, f in self._transformations:
if name == "map":
result = f(element)
Expand Down
Loading
Loading