Skip to content

Commit 17533c9

Browse files
committed
✨ shortcut method on on, on_global and use
1 parent 16c5ce9 commit 17533c9

File tree

5 files changed

+80
-44
lines changed

5 files changed

+80
-44
lines changed

arclet/letoderea/decorate.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,17 +52,17 @@ def wrapper(target: TCallable, /) -> TCallable:
5252
return wrapper
5353

5454

55-
class _Check(Propagator):
55+
class Check(Propagator):
5656
def __init__(self, result: bool, priority: int = 0):
5757
self.predicates = []
5858
self.result = result
5959
self.priority = priority
6060

6161
if TYPE_CHECKING:
62-
def derive(self, predicate: "_Check" | Callable[..., bool] | Callable[..., Awaitable[bool]] | bool) -> Self: ...
62+
def derive(self, predicate: "Check | Callable[..., bool] | Callable[..., Awaitable[bool]] | bool") -> Self: ...
6363
else:
64-
def derive(self, predicate: Union["_Check", Callable[..., bool], Callable[..., Awaitable[bool]], Deref]) -> Self:
65-
if isinstance(predicate, _Check):
64+
def derive(self, predicate: Union["Check", Callable[..., bool], Callable[..., Awaitable[bool]], Deref]) -> Self:
65+
if isinstance(predicate, Check):
6666
self.predicates.extend(predicate.predicates)
6767
else:
6868
self.predicates.append(generate(predicate) if isinstance(predicate, Deref) else predicate)
@@ -105,10 +105,10 @@ def priority(self, value: int):
105105
return self
106106

107107
if TYPE_CHECKING:
108-
def __call__(self, predicate: "_Check" | Callable[..., bool] | Callable[..., Awaitable[bool]] | bool) -> _Check: ...
108+
def __call__(self, predicate: "Check | Callable[..., bool] | Callable[..., Awaitable[bool]] | bool") -> Check: ...
109109
else:
110-
def __call__(self, predicate: Union["_Check", Callable[..., bool], Callable[..., Awaitable[bool]], Deref]) -> _Check:
111-
return _Check(self.result, self._priority).derive(generate(predicate) if isinstance(predicate, Deref) else predicate)
110+
def __call__(self, predicate: Union["Check", Callable[..., bool], Callable[..., Awaitable[bool]], Deref]) -> Check:
111+
return Check(self.result, self._priority).derive(generate(predicate) if isinstance(predicate, Deref) else predicate)
112112

113113
__and__ = __call__
114114
__or__ = __call__

arclet/letoderea/scope.py

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
from __future__ import annotations
22

3+
from dataclasses import dataclass
34
from contextlib import contextmanager
45
from secrets import token_urlsafe
5-
from typing import Any, TypeVar
6-
from collections.abc import Callable
6+
from typing import Any, TypeVar, Generic
7+
from collections.abc import Callable, Awaitable
78

89
from tarina import ContextModel
910

1011
from .provider import TProviders, Provider, ProviderFactory, global_providers
1112
from .publisher import Publisher, _publishers, filter_publisher
1213
from .subscriber import Propagator, Subscriber
14+
from .decorate import Check, enter_if, bypass_if
1315

1416
T = TypeVar("T")
1517

@@ -20,6 +22,40 @@
2022
global_propagators: list[Propagator] = []
2123

2224

25+
@dataclass
26+
class RegisterWrapper(Generic[T]):
27+
_scope: Scope
28+
_event: type | None
29+
_priority: int
30+
_providers: TProviders
31+
_propagators: list[Propagator]
32+
_publisher: Publisher | None
33+
_pub_id: str
34+
_once: bool
35+
_skip_req_missing: bool
36+
37+
def if_(self, predicate: Check | Callable[..., bool] | Callable[..., Awaitable[bool]] | bool, priority: int = 0):
38+
self._propagators.append(enter_if(predicate) / priority)
39+
return self
40+
41+
def unless(self, predicate: Check | Callable[..., bool] | Callable[..., Awaitable[bool]] | bool, priority: int = 0):
42+
self._propagators.append(bypass_if(predicate) / priority)
43+
return self
44+
45+
def propagate(self, *propagators: Propagator):
46+
self._propagators.extend(propagators)
47+
return self
48+
49+
def __call__(self, func: Callable, /) -> Subscriber[T]:
50+
if isinstance(func, Subscriber):
51+
func = func.callable_target
52+
res = Subscriber(func, priority=self._priority, providers=self._providers, dispose=self._scope.remove_subscriber, once=self._once, skip_req_missing=self._skip_req_missing, _listen=self._event)
53+
res.propagates(*self._propagators)
54+
if not self._publisher or (self._publisher and self._publisher.check_subscriber(res)):
55+
self._scope.subscribers[res.id] = (res, self._pub_id)
56+
return res
57+
58+
2359
class Scope:
2460
global_skip_req_missing = False
2561

@@ -85,24 +121,7 @@ def register(self, func: Callable[..., Any] | None = None, event: type | None =
85121
event_providers = _pub.providers
86122
_listen = event
87123

88-
def register_wrapper(exec_target: Callable, /) -> Subscriber:
89-
if isinstance(exec_target, Subscriber):
90-
exec_target = exec_target.callable_target
91-
_providers = [*global_providers, *event_providers, *self.providers, *providers]
92-
res = Subscriber(
93-
exec_target,
94-
priority=priority,
95-
providers=_providers,
96-
dispose=self.remove_subscriber,
97-
once=once,
98-
skip_req_missing=_skip_req_missing,
99-
_listen=_listen,
100-
)
101-
res.propagates(*global_propagators, *self.propagators)
102-
if not _pub or (_pub and _pub.check_subscriber(res)):
103-
self.subscribers[res.id] = (res, pub_id)
104-
return res
105-
124+
register_wrapper = RegisterWrapper(self, _listen, priority, [*global_providers, *event_providers, *self.providers, *providers], [*global_propagators, *self.propagators], _pub, pub_id, once, _skip_req_missing)
106125
if func:
107126
return register_wrapper(func)
108127
return register_wrapper

arclet/letoderea/scope.pyi

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
from contextlib import contextmanager
2+
from dataclasses import dataclass
23
from typing import Any, ClassVar, TypeVar, overload, Generic
34
from collections.abc import Callable, Awaitable, Generator, AsyncGenerator
45

56
from tarina import ContextModel
7+
from typing_extensions import Self
68

9+
from .decorate import Check
710
from .provider import TProviders, Provider, ProviderFactory
811
from .exceptions import ExitState
912
from .publisher import Publisher
1013
from .subscriber import Propagator, Subscriber
1114
from .typing import Resultable
1215

1316
T = TypeVar("T")
17+
TC = TypeVar("TC")
1418
T1 = TypeVar("T1")
1519
_scopes: dict[str, Scope]
1620
scope_ctx: ContextModel[Scope]
@@ -46,15 +50,31 @@ class Scope:
4650
def configure(skip_req_missing: bool = False) -> None: ...
4751

4852

49-
class _Wrapper(Generic[T]):
53+
class RegisterWrapper(Generic[T, TC]):
54+
_scope: Scope
55+
_event: type | None
56+
_priority: int
57+
_providers: TProviders
58+
_propagators: list[Propagator]
59+
_publisher: Publisher | None
60+
_pub_id: str
61+
_once: bool
62+
_skip_req_missing: bool
63+
64+
def if_(self, predicate: Check | Callable[..., bool] | Callable[..., Awaitable[bool]] | bool, priority: int = 0) -> Self: ...
65+
def unless(self, predicate: Check | Callable[..., bool] | Callable[..., Awaitable[bool]] | bool, priority: int = 0) -> Self: ...
66+
def propagate(self, *propagators: Propagator) -> Self: ...
67+
def __init__(self, _scope: Scope, _event: type | None, _priority: int, _providers: TProviders, _propagators: list[Propagator], _publisher: Publisher | None, _pub_id: str, _once: bool = False, _skip_req_missing: bool | None = None): ...
68+
@overload
69+
def __call__(self: RegisterWrapper[None, Callable], func: Callable[..., T1]) -> Subscriber[T1]: ...
5070
@overload
51-
def __call__(self, func: Callable[..., AsyncGenerator[T | ExitState | None, None]]) -> Subscriber[AsyncGenerator[T, None]]: ...
71+
def __call__(self: RegisterWrapper[T, None], func: Callable[..., AsyncGenerator[T | ExitState | None, None]]) -> Subscriber[AsyncGenerator[T, None]]: ...
5272
@overload
53-
def __call__(self, func: Callable[..., Generator[T | ExitState | None, None, None]]) -> Subscriber[Generator[T, None, None]]: ...
73+
def __call__(self: RegisterWrapper[T, None], func: Callable[..., Generator[T | ExitState | None, None, None]]) -> Subscriber[Generator[T, None, None]]: ...
5474
@overload
55-
def __call__(self, func: Callable[..., Awaitable[T | ExitState | None]]) -> Subscriber[Awaitable[T]]: ...
75+
def __call__(self: RegisterWrapper[T, None], func: Callable[..., Awaitable[T | ExitState | None]]) -> Subscriber[Awaitable[T]]: ...
5676
@overload
57-
def __call__(self, func: Callable[..., T | ExitState | None]) -> Subscriber[T]: ...
77+
def __call__(self: RegisterWrapper[T, None], func: Callable[..., T | ExitState | None]) -> Subscriber[T]: ...
5878

5979

6080
@overload
@@ -66,15 +86,15 @@ def on(event: type[Resultable[T1]], func: Callable[..., Awaitable[T1 | ExitState
6686
@overload
6787
def on(event: type[Resultable[T1]], func: Callable[..., T1 | ExitState | None], *, priority: int = 16, providers: TProviders | None = None, once: bool = False, skip_req_missing: bool | None = None) -> Subscriber[T1]: ...
6888
@overload
69-
def on(event: type[Resultable[T1]], *, priority: int = 16, providers: TProviders | None = None, once: bool = False, skip_req_missing: bool | None = None) -> _Wrapper[T1]: ...
89+
def on(event: type[Resultable[T1]], *, priority: int = 16, providers: TProviders | None = None, once: bool = False, skip_req_missing: bool | None = None) -> RegisterWrapper[T1, None]: ...
7090
@overload
7191
def on(event: type[Any], func: Callable[..., T], *, priority: int = 16, providers: TProviders | None = None, once: bool = False, skip_req_missing: bool | None = None) -> Subscriber[T]: ... # type: ignore
7292
@overload
73-
def on(event: type[Any], *, priority: int = 16, providers: TProviders | None = None, once: bool = False, skip_req_missing: bool | None = None) -> Callable[[Callable[..., T]], Subscriber[T]]: ... # type: ignore
93+
def on(event: type[Any], *, priority: int = 16, providers: TProviders | None = None, once: bool = False, skip_req_missing: bool | None = None) -> RegisterWrapper[None, Callable]: ... # type: ignore
7494
@overload
7595
def on_global(func: Callable[..., T], *, priority: int = 16, once: bool = False, skip_req_missing: bool | None = None) -> Subscriber[T]: ...
7696
@overload
77-
def on_global(*, priority: int = 16, once: bool = False, skip_req_missing: bool | None = None) -> Callable[[Callable[..., T]], Subscriber[T]]: ...
97+
def on_global(*, priority: int = 16, once: bool = False, skip_req_missing: bool | None = None) -> RegisterWrapper[None, Callable]: ...
7898
@overload
7999
def use(pub: Publisher[Resultable[T1]], func: Callable[..., Generator[T1 | ExitState | None, None, None]], *, priority: int = 16, providers: TProviders | None = None, once: bool = False, skip_req_missing: bool | None = None) -> Subscriber[Generator[T1, None, None]]: ...
80100
@overload
@@ -84,12 +104,12 @@ def use(pub: Publisher[Resultable[T1]], func: Callable[..., Awaitable[T1 | ExitS
84104
@overload
85105
def use(pub: Publisher[Resultable[T1]], func: Callable[..., T1 | ExitState | None], *, priority: int = 16, providers: TProviders | None = None, once: bool = False, skip_req_missing: bool | None = None) -> Subscriber[T1]: ...
86106
@overload
87-
def use(pub: Publisher[Resultable[T1]], *, priority: int = 16, providers: TProviders | None = None, once: bool = False, skip_req_missing: bool | None = None) -> _Wrapper[T1]: ...
107+
def use(pub: Publisher[Resultable[T1]], *, priority: int = 16, providers: TProviders | None = None, once: bool = False, skip_req_missing: bool | None = None) -> RegisterWrapper[T1, None]: ...
88108
@overload
89109
def use(pub: Publisher[Any], func: Callable[..., T], *, priority: int = 16, providers: TProviders | None = None, once: bool = False, skip_req_missing: bool | None = None) -> Subscriber[T]: ...
90110
@overload
91-
def use(pub: Publisher[Any], *, priority: int = 16, providers: TProviders | None = None, once: bool = False, skip_req_missing: bool | None = None) -> Callable[[Callable[..., T]], Subscriber[T]]: ...
111+
def use(pub: Publisher[Any], *, priority: int = 16, providers: TProviders | None = None, once: bool = False, skip_req_missing: bool | None = None) -> RegisterWrapper[None, Callable]: ...
92112
@overload
93113
def use(pub: str, func: Callable[..., T], *, priority: int = 16, providers: TProviders | None = None, once: bool = False, skip_req_missing: bool | None = None) -> Subscriber[T]: ...
94114
@overload
95-
def use(pub: str, *, priority: int = 16, providers: TProviders | None = None, once: bool = False, skip_req_missing: bool | None = None) -> Callable[[Callable[..., T]], Subscriber[T]]: ...
115+
def use(pub: str, *, priority: int = 16, providers: TProviders | None = None, once: bool = False, skip_req_missing: bool | None = None) -> RegisterWrapper[None, Callable]: ...

tests/test_propagate.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,7 @@ def compose(self):
212212
yield lambda: executed.append(1), True
213213
yield Interval(0.3)
214214

215-
@le.on(PropagateEvent)
216-
@le.propagate(MyPropagator())
215+
@le.on(PropagateEvent).propagate(MyPropagator())
217216
async def s(last_time, foo: str):
218217
assert last_time is None or isinstance(last_time, datetime)
219218
executed.append(foo)

tests/test_shorcuts.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,12 @@ async def s1(a: Annotated[str, func]):
5858
async def test_deref():
5959
executed = []
6060

61-
@on_global
62-
@enter_if(deref(ShortcutEvent).flag) / 100
61+
@on_global().if_(deref(ShortcutEvent).flag, 100)
6362
async def s(flag: Annotated[bool, "flag"]):
6463
assert flag is True
6564
executed.append(1)
6665

67-
@on_global
68-
@bypass_if(deref(ShortcutEvent).flag)
66+
@on_global().unless(deref(ShortcutEvent).flag)
6967
async def s1(
7068
flag: Annotated[bool, deref(ShortcutEvent).flag],
7169
t: Annotated[int, deref(ShortcutEvent).type],

0 commit comments

Comments
 (0)