@@ -66,44 +66,118 @@ auto spawn(F&& f) {
6666 return runtime ().spawn (std::forward<F>(f));
6767}
6868
69- // / Queues the given function to run in the main thread as soon as possible
70- // / and waits for it to complete. Returns null if the function failed to send the result.
71- // / (although that usually cannot happen in practice)
72- template <typename T = void > requires (!std::is_void_v<T>)
73- arc::Future<std::optional<T>> waitForMainThread(Function<T()> func) {
74- auto [tx, rx] = arc::oneshot::channel<T>();
75- auto token = std::make_shared<arc::CancellationToken>();
69+ template <
70+ typename T,
71+ typename NonVoidT = std::conditional_t <std::is_void_v<T>, std::monostate, T>,
72+ typename PollOut = std::conditional_t <std::is_void_v<T>, bool , std::optional<NonVoidT>>
73+ >
74+ struct WaitForMainAwaiter : arc::Pollable<WaitForMainAwaiter<T>, PollOut> {
75+ template <typename F> requires (!std::is_same_v<std::decay_t <F>, WaitForMainAwaiter>)
76+ explicit WaitForMainAwaiter(F&& func) {
77+ m_state = std::make_shared<std::atomic<State>>(State::Pending);
78+ auto [tx, rx] = arc::oneshot::channel<NonVoidT>();
79+ m_receiver.emplace (std::move (rx));
80+ m_recvAwaiter.emplace (m_receiver->recv ());
81+
82+ geode::queueInMainThread ([state = m_state, func = std::forward<F>(func), tx = std::move (tx)] mutable {
83+ auto expected = State::Pending;
84+ if (!state->compare_exchange_strong (expected, State::Running, std::memory_order::acq_rel)) {
85+ // cancelled before the function started running, simply exit
86+ return ;
87+ }
7688
77- auto _ = arc::scopeDtor ([&] {
78- token->cancel ();
79- });
89+ auto complete = [&]<typename X>(X&& val) {
90+ // the state must be either Running or RunningCancelled, depending on this we decide whether to post the result or not
91+ bool shouldPost = State::Running == state->exchange (State::Completed, std::memory_order::acq_rel);
92+
93+ if (shouldPost) {
94+ (void ) tx.send (std::forward<X>(val));
95+ } else {
96+ state->notify_one ();
97+ }
98+ };
99+
100+ if constexpr (std::is_void_v<T>) {
101+ func ();
102+ complete (std::monostate{});
103+ } else {
104+ complete (func ());
105+ }
106+ });
107+ }
80108
81- geode::queueInMainThread ([func = std::move (func), tx = std::move (tx), token] mutable {
82- if (token->isCancelled ()) return ;
83- (void ) tx.send (func ());
84- });
109+ ~WaitForMainAwaiter () {
110+ if (!m_state || !m_receiver) return ;
111+
112+ auto state = m_state->load (std::memory_order::acquire);
113+ while (true ) {
114+ switch (state) {
115+ case State::Pending: {
116+ // function hasn't ran yet, all we have to do is cancel it
117+ if (m_state->compare_exchange_weak (state, State::Completed, std::memory_order::acq_rel)) {
118+ return ;
119+ }
120+ } break ;
121+
122+ case State::Running: {
123+ // currently running, we need to wait for it to finish
124+ // the function may have captured local environment, prevent UB by waiting until it finishes before we destroy anything
125+ if (m_state->compare_exchange_weak (state, State::RunningCancelled, std::memory_order::acq_rel)) {
126+ m_state->wait (State::RunningCancelled, std::memory_order::acquire);
127+ return ;
128+ }
129+ } break ;
130+
131+ // nothing to do if completed, and all other states are impossible here
132+ case State::Completed:
133+ default : return ;
134+ }
135+ }
136+ }
85137
86- co_return (co_await rx.recv ()).ok ();
87- }
138+ WaitForMainAwaiter (WaitForMainAwaiter const &) = delete ;
139+ WaitForMainAwaiter& operator =(WaitForMainAwaiter const &) = delete ;
140+ WaitForMainAwaiter (WaitForMainAwaiter&&) = default ;
141+ WaitForMainAwaiter& operator =(WaitForMainAwaiter&&) = delete ;
142+
143+ std::optional<PollOut> poll (arc::Context& cx) {
144+ auto pres = m_recvAwaiter->poll (cx);
145+ if (!pres) return std::nullopt ;
146+
147+ // res is RecvResult<T>, return Some(nullopt/false) if it failed
148+ auto res = std::move (*pres);
149+
150+ if constexpr (std::is_void_v<T>) {
151+ return std::optional{res.isOk ()};
152+ } else {
153+ return std::move (res).ok ();
154+ }
155+ }
156+
157+ private:
158+ enum class State : uint8_t {
159+ // / The function has been queued to run, but is not yet running
160+ Pending,
161+ // / The function is currently running in the main thread
162+ Running,
163+ // / The function is still running in the main thread, but it has been cancelled
164+ RunningCancelled,
165+ // / The function either completed and posted the result, or was cancelled and is not currently running
166+ Completed,
167+ };
168+
169+ std::shared_ptr<std::atomic<State>> m_state;
170+ std::optional<arc::oneshot::Receiver<NonVoidT>> m_receiver;
171+ std::optional<arc::oneshot::RecvAwaiter<NonVoidT>> m_recvAwaiter;
172+ };
88173
89174// / Queues the given function to run in the main thread as soon as possible
90- // / and waits for it to complete. Returns false if the function failed to complete (e.g. due to exception)
91- template <typename T = void > requires (std::is_void_v<T>)
92- arc::Future<bool> waitForMainThread(Function<void ()> func) {
93- auto [tx, rx] = arc::oneshot::channel<std::monostate>();
94- auto token = std::make_shared<arc::CancellationToken>();
95-
96- auto _ = arc::scopeDtor ([&] {
97- token->cancel ();
98- });
99-
100- geode::queueInMainThread ([func = std::move (func), tx = std::move (tx), token] mutable {
101- if (token->isCancelled ()) return ;
102- func ();
103- (void ) tx.send ({});
104- });
105-
106- co_return (co_await rx.recv ()).isOk ();
175+ // / and waits for it to complete. Returns null/false if the function failed to send the result.
176+ // / (although that usually cannot happen in practice)
177+ // / This pollable is cancel-safe and won't be destroyed until the function finishes running or is confirmed to be aborted.
178+ template <typename T = void , typename F>
179+ auto waitForMainThread (F&& func) {
180+ return WaitForMainAwaiter<T>(std::forward<F>(func));
107181}
108182
109183// / Allows an async task to be spawned and then automatically aborted when the holder goes out of scope.
0 commit comments