Skip to content

[core] Implement a thread pool and call the CPython API on all threads within the same concurrency group #52575

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
May 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 59 additions & 3 deletions python/ray/tests/test_concurrency_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import ray
from ray._common.utils import get_or_create_event_loop
from ray._private.test_utils import run_string_as_driver
from ray._private.test_utils import run_string_as_driver, SignalActor


# This tests the methods are executed in the correct eventloop.
Expand Down Expand Up @@ -197,8 +197,12 @@ def get_thread_local(self) -> Tuple[Any, int]:

class TestThreadingLocalData:
"""
This test verifies that synchronous tasks can access thread local data
that was set by previous synchronous tasks.
This test verifies that synchronous tasks can access thread-local data that
was set by previous synchronous tasks when the concurrency group has only
one thread. For concurrency groups with multiple threads, it doesn't promise
access to the same thread-local data because Ray currently doesn't expose APIs
for users to specify which thread the task will be scheduled on in the same
concurrency group.
"""

def test_tasks_on_default_executor(self, ray_start_regular_shared):
Expand Down Expand Up @@ -236,6 +240,58 @@ def test_tasks_on_different_executors(self, ray_start_regular_shared):
assert value == "f2"


def test_multiple_threads_in_same_group(ray_start_regular_shared):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. We can't use threading.enumerate() to check the number of threads because the threads are not launched by Python.

  2. The threads are visible for py-spy because it checks the information from OS.

"""
This test verifies that all threads in the same concurrency group are still
alive from the Python interpreter's perspective even if Ray tasks have finished, so that
thread-local data will not be garbage collected.
"""

@ray.remote
class Actor:
def __init__(self, signal: SignalActor, max_concurrency: int):
self._thread_local_data = threading.local()
self.signal = signal
self.thread_id_to_data = {}
self.max_concurrency = max_concurrency

def set_thread_local(self, value: int) -> int:
# If the thread-local data were garbage collected after the previous
# task on the same thread finished, `self.data` would be incremented
# more than once for the same thread.
assert not hasattr(self._thread_local_data, "value")
self._thread_local_data.value = value
self.thread_id_to_data[threading.current_thread().ident] = value
ray.get(self.signal.wait.remote())

def check_thread_local_data(self) -> bool:
assert len(self.thread_id_to_data) == self.max_concurrency
assert hasattr(self._thread_local_data, "value")
assert (
self._thread_local_data.value
== self.thread_id_to_data[threading.current_thread().ident]
)
ray.get(self.signal.wait.remote())

max_concurrency = 5
signal = SignalActor.remote()
a = Actor.options(max_concurrency=max_concurrency).remote(signal, max_concurrency)

refs = []
for i in range(max_concurrency):
refs.append(a.set_thread_local.remote(i))

ray.get(signal.send.remote())
ray.get(refs)

refs = []
for _ in range(max_concurrency):
refs.append(a.check_thread_local_data.remote())

ray.get(signal.send.remote())
ray.get(refs)


def test_invalid_concurrency_group():
"""Verify that when a concurrency group has max concurrency set to 0,
an error is raised when the actor is created. This test uses
Expand Down
1 change: 1 addition & 0 deletions src/ray/core_worker/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -576,5 +576,6 @@ ray_cc_library(
deps = [
"//src/ray/util:logging",
"@boost//:asio",
"@boost//:thread",
],
)
6 changes: 5 additions & 1 deletion src/ray/core_worker/fiber.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ class FiberState {
return true;
}

explicit FiberState(int max_concurrency)
explicit FiberState(
int max_concurrency,
// TODO(kevin85421): The language-specific callback function that
// initializes threads. It's not currently used in the async mode.
std::function<std::function<void()>()> initialize_thread_callback = nullptr)
: allocator_(kStackSize),
rate_limiter_(max_concurrency),
fiber_stopped_event_(std::make_shared<StdEvent>()) {
Expand Down
11 changes: 11 additions & 0 deletions src/ray/core_worker/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,17 @@ ray_cc_test(
],
)

ray_cc_test(
name = "thread_pool_test",
srcs = ["thread_pool_test.cc"],
tags = ["team:core"],
deps = [
"//src/ray/core_worker:thread_pool",
"@com_google_googletest//:gtest",
"@com_google_googletest//:gtest_main",
],
)

ray_cc_test(
name = "concurrency_group_manager_test",
srcs = ["concurrency_group_manager_test.cc"],
Expand Down
132 changes: 132 additions & 0 deletions src/ray/core_worker/test/thread_pool_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
// Copyright 2017 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "ray/core_worker/transport/thread_pool.h"

#include <gtest/gtest.h>

#include <atomic>
#include <boost/thread/latch.hpp>
#include <future>

namespace ray {
namespace core {

TEST(BoundedExecutorTest, InitializeThreadCallbackAndReleaserAreCalled) {
constexpr int kNumThreads = 3;
std::atomic<int> init_count{0};
std::atomic<int> release_count{0};

// The callback increments init_count and returns a releaser that increments
// release_count.
auto initialize_thread_callback = [&]() {
init_count++;
return [&]() { release_count++; };
};

{
BoundedExecutor executor(kNumThreads, initialize_thread_callback);
// At this point, all threads should have called the initializer.
ASSERT_EQ(init_count.load(), kNumThreads);
ASSERT_EQ(release_count.load(), 0);

std::atomic<int> task_count{0};
auto callback = [&]() {
task_count++;
while (task_count.load() < kNumThreads) {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
}
};

// Make sure all threads can run tasks.
for (int i = 0; i < kNumThreads; i++) {
executor.Post(callback);
}

// Join the pool, which should call the releasers.
executor.Join();
ASSERT_EQ(task_count.load(), kNumThreads);
}
// After join, all releasers should have been called.
ASSERT_EQ(release_count.load(), kNumThreads);
}

TEST(BoundedExecutorTest, InitializationTimeout) {
constexpr int kNumThreads = 3;

// Create a callback that will hang indefinitely to trigger the timeout
auto initialize_thread_callback = [&]() {
while (true) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
return nullptr;
};

// Verify that the constructor fails with the expected error message
EXPECT_DEATH(
BoundedExecutor executor(
kNumThreads, initialize_thread_callback, boost::chrono::milliseconds(10)),
"Failed to initialize threads in 10 milliseconds");
}

TEST(BoundedExecutorTest, PostBlockingIfFull) {
constexpr int kNumThreads = 3;
BoundedExecutor executor(kNumThreads);

boost::latch latch(kNumThreads);
std::atomic<bool> block{true};
auto callback = [&]() {
latch.count_down();
while (block.load()) {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
}
};

for (int i = 0; i < kNumThreads; i++) {
executor.Post(callback);
}
latch.wait();

// Submit a new task. It should not run immediately
// because the thread pool is full.
std::atomic<bool> running{false};
std::promise<void> promise;
std::future<void> future = promise.get_future();
executor.Post([&]() {
running = true;
promise.set_value();
});

// Make sure the task is not running yet after 50 ms.
std::this_thread::sleep_for(std::chrono::milliseconds(50));
ASSERT_FALSE(running.load());

// Unblock the threads. The task should run immediately.
block.store(false);

// Wait for the task with a timeout
auto status = future.wait_for(std::chrono::milliseconds(500));
ASSERT_EQ(status, std::future_status::ready) << "Task did not complete within timeout";
ASSERT_TRUE(running.load());

executor.Join();
}

} // namespace core
} // namespace ray

int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
52 changes: 5 additions & 47 deletions src/ray/core_worker/transport/concurrency_group_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ ConcurrencyGroupManager<ExecutorType>::ConcurrencyGroupManager(
for (auto &group : concurrency_groups) {
const auto name = group.name;
const auto max_concurrency = group.max_concurrency;
auto executor = std::make_shared<ExecutorType>(max_concurrency);
executor_releasers_.push_back(InitializeExecutor(executor));
auto executor =
std::make_shared<ExecutorType>(max_concurrency, initialize_thread_callback_);
auto &fds = group.function_descriptors;
for (auto fd : fds) {
functions_to_executor_index_[fd->ToString()] = executor;
Expand All @@ -50,9 +50,8 @@ ConcurrencyGroupManager<ExecutorType>::ConcurrencyGroupManager(
// the thread pools instead of main thread.
if (ExecutorType::NeedDefaultExecutor(max_concurrency_for_default_concurrency_group,
!concurrency_groups.empty())) {
default_executor_ =
std::make_shared<ExecutorType>(max_concurrency_for_default_concurrency_group);
executor_releasers_.push_back(InitializeExecutor(default_executor_));
default_executor_ = std::make_shared<ExecutorType>(
max_concurrency_for_default_concurrency_group, initialize_thread_callback_);
}
}

Expand All @@ -62,7 +61,7 @@ std::shared_ptr<ExecutorType> ConcurrencyGroupManager<ExecutorType>::GetExecutor
if (concurrency_group_name == RayConfig::instance().system_concurrency_group_name() &&
name_to_executor_index_.find(concurrency_group_name) ==
name_to_executor_index_.end()) {
auto executor = std::make_shared<ExecutorType>(1);
auto executor = std::make_shared<ExecutorType>(1, initialize_thread_callback_);
name_to_executor_index_[concurrency_group_name] = executor;
}

Expand Down Expand Up @@ -91,50 +90,9 @@ std::shared_ptr<ExecutorType> ConcurrencyGroupManager<ExecutorType>::GetDefaultE
return default_executor_;
}

template <typename ExecutorType>
std::optional<std::function<void()>>
ConcurrencyGroupManager<ExecutorType>::InitializeExecutor(
std::shared_ptr<ExecutorType> executor) {
if (!initialize_thread_callback_) {
return std::nullopt;
}

if constexpr (std::is_same<ExecutorType, BoundedExecutor>::value) {
std::promise<void> init_promise;
auto init_future = init_promise.get_future();
auto initializer = initialize_thread_callback_;
std::function<void()> releaser;

executor->Post([&initializer, &init_promise, &releaser]() {
releaser = initializer();
init_promise.set_value();
});

// Wait for thread initialization to complete before executing any tasks in the
// executor.
init_future.wait();

return [executor, releaser]() {
std::promise<void> release_promise;
auto release_future = release_promise.get_future();
executor->Post([releaser, &release_promise]() {
releaser();
release_promise.set_value();
});
release_future.wait();
};
}
return std::nullopt;
}

/// Stop and join the executors that the this manager owns.
template <typename ExecutorType>
void ConcurrencyGroupManager<ExecutorType>::Stop() {
for (const auto &releaser : executor_releasers_) {
if (releaser.has_value()) {
releaser.value()();
}
}
if (default_executor_) {
RAY_LOG(DEBUG) << "Default executor is stopping.";
default_executor_->Stop();
Expand Down
11 changes: 0 additions & 11 deletions src/ray/core_worker/transport/concurrency_group_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,6 @@ class ConcurrencyGroupManager final {
std::shared_ptr<ExecutorType> GetExecutor(const std::string &concurrency_group_name,
const ray::FunctionDescriptor &fd);

/// Initialize the executor for specific language runtime.
///
/// \param executor The executor to be initialized.

/// \return A function that will be called when destructing the executor.
std::optional<std::function<void()>> InitializeExecutor(
std::shared_ptr<ExecutorType> executor);

/// Get the default executor.
std::shared_ptr<ExecutorType> GetDefaultExecutor() const;

Expand All @@ -83,9 +75,6 @@ class ConcurrencyGroupManager final {
// The language-specific callback function that initializes threads.
std::function<std::function<void()>()> initialize_thread_callback_;

// A vector of language-specific functions used to release the executors.
std::vector<std::optional<std::function<void()>>> executor_releasers_;

friend class ConcurrencyGroupManagerTest;
};

Expand Down
Loading