Skip to content

Commit 4b851bd

Browse files
committed
make waitForMainThread fully cancel safe
1 parent 9b4fd53 commit 4b851bd

File tree

1 file changed

+107
-33
lines changed

1 file changed

+107
-33
lines changed

loader/include/Geode/utils/async.hpp

Lines changed: 107 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -66,44 +66,118 @@ auto spawn(F&& f) {
6666
return runtime().spawn(std::forward<F>(f));
6767
}
6868

69-
/// Queues the given function to run in the main thread as soon as possible
70-
/// and waits for it to complete. Returns null if the function failed to send the result.
71-
/// (although that usually cannot happen in practice)
72-
template <typename T = void> requires (!std::is_void_v<T>)
73-
arc::Future<std::optional<T>> waitForMainThread(Function<T()> func) {
74-
auto [tx, rx] = arc::oneshot::channel<T>();
75-
auto token = std::make_shared<arc::CancellationToken>();
69+
template <
70+
typename T,
71+
typename NonVoidT = std::conditional_t<std::is_void_v<T>, std::monostate, T>,
72+
typename PollOut = std::conditional_t<std::is_void_v<T>, bool, std::optional<NonVoidT>>
73+
>
74+
struct WaitForMainAwaiter : arc::Pollable<WaitForMainAwaiter<T>, PollOut> {
75+
template <typename F> requires (!std::is_same_v<std::decay_t<F>, WaitForMainAwaiter>)
76+
explicit WaitForMainAwaiter(F&& func) {
77+
m_state = std::make_shared<std::atomic<State>>(State::Pending);
78+
auto [tx, rx] = arc::oneshot::channel<NonVoidT>();
79+
m_receiver.emplace(std::move(rx));
80+
m_recvAwaiter.emplace(m_receiver->recv());
81+
82+
geode::queueInMainThread([state = m_state, func = std::forward<F>(func), tx = std::move(tx)] mutable {
83+
auto expected = State::Pending;
84+
if (!state->compare_exchange_strong(expected, State::Running, std::memory_order::acq_rel)) {
85+
// cancelled before the function started running, simply exit
86+
return;
87+
}
7688

77-
auto _ = arc::scopeDtor([&] {
78-
token->cancel();
79-
});
89+
auto complete = [&]<typename X>(X&& val) {
90+
// the state must be either Running or RunningCancelled, depending on this we decide whether to post the result or not
91+
bool shouldPost = State::Running == state->exchange(State::Completed, std::memory_order::acq_rel);
92+
93+
if (shouldPost) {
94+
(void) tx.send(std::forward<X>(val));
95+
} else {
96+
state->notify_one();
97+
}
98+
};
99+
100+
if constexpr (std::is_void_v<T>) {
101+
func();
102+
complete(std::monostate{});
103+
} else {
104+
complete(func());
105+
}
106+
});
107+
}
80108

81-
geode::queueInMainThread([func = std::move(func), tx = std::move(tx), token] mutable {
82-
if (token->isCancelled()) return;
83-
(void) tx.send(func());
84-
});
109+
~WaitForMainAwaiter() {
110+
if (!m_state || !m_receiver) return;
111+
112+
auto state = m_state->load(std::memory_order::acquire);
113+
while (true) {
114+
switch (state) {
115+
case State::Pending: {
116+
// function hasn't ran yet, all we have to do is cancel it
117+
if (m_state->compare_exchange_weak(state, State::Completed, std::memory_order::acq_rel)) {
118+
return;
119+
}
120+
} break;
121+
122+
case State::Running: {
123+
// currently running, we need to wait for it to finish
124+
// the function may have captured local environment, prevent UB by waiting until it finishes before we destroy anything
125+
if (m_state->compare_exchange_weak(state, State::RunningCancelled, std::memory_order::acq_rel)) {
126+
m_state->wait(State::RunningCancelled, std::memory_order::acquire);
127+
return;
128+
}
129+
} break;
130+
131+
// nothing to do if completed, and all other states are impossible here
132+
case State::Completed:
133+
default: return;
134+
}
135+
}
136+
}
85137

86-
co_return (co_await rx.recv()).ok();
87-
}
138+
WaitForMainAwaiter(WaitForMainAwaiter const&) = delete;
139+
WaitForMainAwaiter& operator=(WaitForMainAwaiter const&) = delete;
140+
WaitForMainAwaiter(WaitForMainAwaiter&&) = default;
141+
WaitForMainAwaiter& operator=(WaitForMainAwaiter&&) = delete;
142+
143+
std::optional<PollOut> poll(arc::Context& cx) {
144+
auto pres = m_recvAwaiter->poll(cx);
145+
if (!pres) return std::nullopt;
146+
147+
// res is RecvResult<T>, return Some(nullopt/false) if it failed
148+
auto res = std::move(*pres);
149+
150+
if constexpr (std::is_void_v<T>) {
151+
return std::optional{res.isOk()};
152+
} else {
153+
return std::move(res).ok();
154+
}
155+
}
156+
157+
private:
158+
enum class State : uint8_t {
159+
/// The function has been queued to run, but is not yet running
160+
Pending,
161+
/// The function is currently running in the main thread
162+
Running,
163+
/// The function is still running in the main thread, but it has been cancelled
164+
RunningCancelled,
165+
/// The function either completed and posted the result, or was cancelled and is not currently running
166+
Completed,
167+
};
168+
169+
std::shared_ptr<std::atomic<State>> m_state;
170+
std::optional<arc::oneshot::Receiver<NonVoidT>> m_receiver;
171+
std::optional<arc::oneshot::RecvAwaiter<NonVoidT>> m_recvAwaiter;
172+
};
88173

89174
/// Queues the given function to run in the main thread as soon as possible
90-
/// and waits for it to complete. Returns false if the function failed to complete (e.g. due to exception)
91-
template <typename T = void> requires (std::is_void_v<T>)
92-
arc::Future<bool> waitForMainThread(Function<void()> func) {
93-
auto [tx, rx] = arc::oneshot::channel<std::monostate>();
94-
auto token = std::make_shared<arc::CancellationToken>();
95-
96-
auto _ = arc::scopeDtor([&] {
97-
token->cancel();
98-
});
99-
100-
geode::queueInMainThread([func = std::move(func), tx = std::move(tx), token] mutable {
101-
if (token->isCancelled()) return;
102-
func();
103-
(void) tx.send({});
104-
});
105-
106-
co_return (co_await rx.recv()).isOk();
175+
/// and waits for it to complete. Returns null/false if the function failed to send the result.
176+
/// (although that usually cannot happen in practice)
177+
/// This pollable is cancel-safe and won't be destroyed until the function finishes running or is confirmed to be aborted.
178+
template <typename T = void, typename F>
179+
auto waitForMainThread(F&& func) {
180+
return WaitForMainAwaiter<T>(std::forward<F>(func));
107181
}
108182

109183
/// Allows an async task to be spawned and then automatically aborted when the holder goes out of scope.

0 commit comments

Comments
 (0)