[core] Implement a thread pool and call the CPython API on all threads within the same concurrency group #52575

Open: wants to merge 12 commits into base: master
2 changes: 1 addition & 1 deletion .buildkite/build.rayci.yml
@@ -22,7 +22,7 @@ steps:
tags:
- linux_wheels
- oss
instance_type: medium
instance_type: large
commands:
- bazel run //ci/ray_ci:build_in_docker -- wheel --build-type debug --upload
depends_on:
41 changes: 39 additions & 2 deletions python/ray/tests/test_concurrency_group.py
@@ -197,8 +197,12 @@ def get_thread_local(self) -> Tuple[Any, int]:

class TestThreadingLocalData:
"""
This test verifies that synchronous tasks can access thread local data
that was set by previous synchronous tasks.
This test verifies that synchronous tasks can access thread-local data that
was set by previous synchronous tasks when the concurrency group has only
one thread. For concurrency groups with multiple threads, it doesn't promise
access to the same thread-local data because Ray currently doesn't expose APIs
for users to specify which thread the task will be scheduled on in the same
concurrency group.
"""

def test_tasks_on_default_executor(self, ray_start_regular_shared):
@@ -236,6 +240,39 @@ def test_tasks_on_different_executors(self, ray_start_regular_shared):
assert value == "f2"


def test_multiple_threads_in_same_group(ray_start_regular_shared):
Member Author

  1. We can't use `threading.enumerate()` to check the number of threads, because these threads are not launched by Python.

  2. The threads are visible to py-spy because it reads thread information from the OS.

"""
This test verifies that all threads in the same concurrency group are still
alive from the Python interpreter's perspective even if Ray tasks have finished, so that
thread-local data will not be garbage collected.
"""

@ray.remote
class Actor:
def __init__(self):
self.data = 0
self._thread_local_data = threading.local()

def set_thread_local(self, value: Any) -> int:
# If the thread-local data were garbage collected after the previous
# task on the same thread finished, `self.data` would be incremented
# more than once for the same thread.
if not hasattr(self._thread_local_data, "value"):
self._thread_local_data.value = self.data
self.data += 1
assert self._thread_local_data.value <= self.data

def get_data(self) -> int:
return self.data

max_concurrency = 5
a = Actor.options(max_concurrency=max_concurrency).remote()
for _ in range(200):
for i in range(max_concurrency):
ray.get(a.set_thread_local.remote(i))
assert ray.get(a.get_data.remote()) == max_concurrency
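
The thread-local bookkeeping that `test_multiple_threads_in_same_group` relies on can be tried outside Ray. The sketch below is a plain-Python analogue, not Ray code: it assumes a standard-library `ThreadPoolExecutor` can stand in for a concurrency group, and the names `set_thread_local`, `init_count`, and `MAX_WORKERS` are illustrative only. Each worker thread initializes its thread-local slot at most once, so the counter can never exceed the number of workers as long as the threads stay alive between tasks.

```python
import threading
from concurrent.futures import ThreadPoolExecutor

MAX_WORKERS = 5  # stands in for max_concurrency (assumption for this sketch)

thread_local = threading.local()
init_count = 0
init_lock = threading.Lock()


def set_thread_local() -> int:
    """Initialize the calling thread's slot once and return its value."""
    global init_count
    # `hasattr` is evaluated against the calling thread's own storage, so it
    # is False only the first time a given worker thread runs this function.
    if not hasattr(thread_local, "value"):
        with init_lock:
            thread_local.value = init_count
            init_count += 1
    return thread_local.value


with ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
    # Submit tasks one at a time, as the Ray test does with ray.get().
    seen = [pool.submit(set_thread_local).result() for _ in range(200)]

# If thread-local data were dropped between tasks, init_count would grow
# past the number of worker threads.
assert 1 <= init_count <= MAX_WORKERS
```

Unlike the Ray test, this sketch asserts only an upper bound, because `ThreadPoolExecutor` starts workers lazily and may serve sequential tasks with fewer than `MAX_WORKERS` threads.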


def test_invalid_concurrency_group():
"""Verify that when a concurrency group has max concurrency set to 0,
an error is raised when the actor is created. This test uses
3 changes: 2 additions & 1 deletion src/ray/core_worker/fiber.h
@@ -101,7 +101,8 @@ class FiberState {
return true;
}

explicit FiberState(int max_concurrency)
explicit FiberState(int max_concurrency,
std::function<std::function<void()>()> = nullptr)
: allocator_(kStackSize),
rate_limiter_(max_concurrency),
fiber_stopped_event_(std::make_shared<StdEvent>()) {
11 changes: 11 additions & 0 deletions src/ray/core_worker/test/BUILD.bazel
@@ -153,6 +153,17 @@ ray_cc_test(
],
)

ray_cc_test(
name = "thread_pool_test",
srcs = ["thread_pool_test.cc"],
tags = ["team:core"],
deps = [
"//src/ray/core_worker:thread_pool",
"@com_google_googletest//:gtest",
"@com_google_googletest//:gtest_main",
],
)

ray_cc_test(
name = "concurrency_group_manager_test",
srcs = ["concurrency_group_manager_test.cc"],
61 changes: 61 additions & 0 deletions src/ray/core_worker/test/thread_pool_test.cc
@@ -0,0 +1,61 @@
// Copyright 2017 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "ray/core_worker/transport/thread_pool.h"

#include <gtest/gtest.h>

#include <atomic>
#include <future>

namespace ray {
namespace core {

TEST(BoundedExecutorTest, InitializeThreadCallbackAndReleaserAreCalled) {
constexpr int kNumThreads = 3;
std::atomic<int> init_count{0};
std::atomic<int> release_count{0};

// The callback increments init_count and returns a releaser that increments
// release_count.
auto initialize_thread_callback = [&]() {
init_count++;
return [&]() { release_count++; };
};

{
BoundedExecutor executor(kNumThreads, initialize_thread_callback);
// At this point, all threads should have called the initializer.
ASSERT_EQ(init_count.load(), kNumThreads);
ASSERT_EQ(release_count.load(), 0);

// Post a dummy task to ensure threads are running.
std::promise<void> p;
executor.Post([&] { p.set_value(); });
p.get_future().wait();

// Join the pool, which should call the releasers.
executor.Join();
}
// After join, all releasers should have been called.
ASSERT_EQ(release_count.load(), kNumThreads);
}

} // namespace core
} // namespace ray

int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
52 changes: 5 additions & 47 deletions src/ray/core_worker/transport/concurrency_group_manager.cc
@@ -35,8 +35,8 @@ ConcurrencyGroupManager<ExecutorType>::ConcurrencyGroupManager(
for (auto &group : concurrency_groups) {
const auto name = group.name;
const auto max_concurrency = group.max_concurrency;
auto executor = std::make_shared<ExecutorType>(max_concurrency);
executor_releasers_.push_back(InitializeExecutor(executor));
auto executor =
std::make_shared<ExecutorType>(max_concurrency, initialize_thread_callback_);
auto &fds = group.function_descriptors;
for (auto fd : fds) {
functions_to_executor_index_[fd->ToString()] = executor;
@@ -50,9 +50,8 @@ ConcurrencyGroupManager<ExecutorType>::ConcurrencyGroupManager(
// the thread pools instead of main thread.
if (ExecutorType::NeedDefaultExecutor(max_concurrency_for_default_concurrency_group,
!concurrency_groups.empty())) {
default_executor_ =
std::make_shared<ExecutorType>(max_concurrency_for_default_concurrency_group);
executor_releasers_.push_back(InitializeExecutor(default_executor_));
default_executor_ = std::make_shared<ExecutorType>(
max_concurrency_for_default_concurrency_group, initialize_thread_callback_);
}
}

@@ -62,7 +61,7 @@ std::shared_ptr<ExecutorType> ConcurrencyGroupManager<ExecutorType>::GetExecutor
if (concurrency_group_name == RayConfig::instance().system_concurrency_group_name() &&
name_to_executor_index_.find(concurrency_group_name) ==
name_to_executor_index_.end()) {
auto executor = std::make_shared<ExecutorType>(1);
auto executor = std::make_shared<ExecutorType>(1, initialize_thread_callback_);
name_to_executor_index_[concurrency_group_name] = executor;
}

@@ -91,50 +90,9 @@ std::shared_ptr<ExecutorType> ConcurrencyGroupManager<ExecutorType>::GetDefaultE
return default_executor_;
}

template <typename ExecutorType>
std::optional<std::function<void()>>
ConcurrencyGroupManager<ExecutorType>::InitializeExecutor(
std::shared_ptr<ExecutorType> executor) {
if (!initialize_thread_callback_) {
return std::nullopt;
}

if constexpr (std::is_same<ExecutorType, BoundedExecutor>::value) {
std::promise<void> init_promise;
auto init_future = init_promise.get_future();
auto initializer = initialize_thread_callback_;
std::function<void()> releaser;

executor->Post([&initializer, &init_promise, &releaser]() {
releaser = initializer();
init_promise.set_value();
});

// Wait for thread initialization to complete before executing any tasks in the
// executor.
init_future.wait();

return [executor, releaser]() {
std::promise<void> release_promise;
auto release_future = release_promise.get_future();
executor->Post([releaser, &release_promise]() {
releaser();
release_promise.set_value();
});
release_future.wait();
};
}
return std::nullopt;
}

/// Stop and join the executors that this manager owns.
template <typename ExecutorType>
void ConcurrencyGroupManager<ExecutorType>::Stop() {
for (const auto &releaser : executor_releasers_) {
if (releaser.has_value()) {
releaser.value()();
}
}
if (default_executor_) {
RAY_LOG(DEBUG) << "Default executor is stopping.";
default_executor_->Stop();
11 changes: 0 additions & 11 deletions src/ray/core_worker/transport/concurrency_group_manager.h
@@ -55,14 +55,6 @@ class ConcurrencyGroupManager final {
std::shared_ptr<ExecutorType> GetExecutor(const std::string &concurrency_group_name,
const ray::FunctionDescriptor &fd);

/// Initialize the executor for specific language runtime.
///
/// \param executor The executor to be initialized.

/// \return A function that will be called when destructing the executor.
std::optional<std::function<void()>> InitializeExecutor(
std::shared_ptr<ExecutorType> executor);

/// Get the default executor.
std::shared_ptr<ExecutorType> GetDefaultExecutor() const;

@@ -83,9 +75,6 @@ class ConcurrencyGroupManager final {
// The language-specific callback function that initializes threads.
std::function<std::function<void()>()> initialize_thread_callback_;

// A vector of language-specific functions used to release the executors.
std::vector<std::optional<std::function<void()>>> executor_releasers_;

friend class ConcurrencyGroupManagerTest;
};

45 changes: 41 additions & 4 deletions src/ray/core_worker/transport/thread_pool.cc
@@ -14,21 +14,58 @@

#include "ray/core_worker/transport/thread_pool.h"

#include <boost/asio/post.hpp>
#include <future>
#include <memory>
#include <utility>

namespace ray {
namespace core {

BoundedExecutor::BoundedExecutor(int max_concurrency) {
BoundedExecutor::BoundedExecutor(
int max_concurrency,
std::function<std::function<void()>()> initialize_thread_callback)
Collaborator

Let's add C++ unit tests to check the `initialize_thread_callback` behavior:

  1. The callback is called when the thread pool is constructed.
  2. The releaser is called when the pool is joined.

Member Author

Updated.

: work_guard_(boost::asio::make_work_guard(io_context_)) {
RAY_CHECK(max_concurrency > 0) << "max_concurrency must be greater than 0";
pool_ = std::make_unique<boost::asio::thread_pool>(max_concurrency);
threads_.reserve(max_concurrency);
for (int i = 0; i < max_concurrency; i++) {
std::promise<void> init_promise;
auto init_future = init_promise.get_future();
threads_.emplace_back([this, initialize_thread_callback, &init_promise]() {
std::function<void()> releaser;
if (initialize_thread_callback) {
releaser = initialize_thread_callback();
}
init_promise.set_value();
// `io_context_.run()` will block until `work_guard_.reset()` is called.
io_context_.run();
Member Author

This call returns after `work_guard_.reset()` is called.

if (releaser) {
releaser();
}
});
init_future.wait();
}
}

void BoundedExecutor::Post(std::function<void()> fn) {
boost::asio::post(io_context_, std::move(fn));
}

/// Stop the thread pool.
void BoundedExecutor::Stop() { pool_->stop(); }
void BoundedExecutor::Stop() {
work_guard_.reset();
io_context_.stop();
}

/// Join the thread pool.
void BoundedExecutor::Join() { pool_->join(); }
void BoundedExecutor::Join() {
work_guard_.reset();
Member Author

@kevin85421, May 1, 2025

This maintains the previous behavior. We can't assume that `Join` will always be called after `Stop`, so `work_guard_` must be reset here as well.

It's fine to call `work_guard_.reset()` twice; the second call does nothing.

for (auto &thread : threads_) {
if (thread.joinable()) {
thread.join();
}
}
}

} // namespace core
} // namespace ray
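
The lifecycle that the new `BoundedExecutor` constructor and `Join` implement (run an initializer on every worker thread before the constructor returns, then run each thread's releaser on that same thread once the work loop exits) can be sketched in Python. This is a hypothetical analogue and not part of the PR: the class `BoundedExecutorSketch`, the stop sentinel, and the use of `queue.Queue` in place of `boost::asio::io_context` are all assumptions made for illustration. It does not model `Stop()`.

```python
import queue
import threading
from typing import Callable, Optional

Releaser = Callable[[], None]


class BoundedExecutorSketch:
    """Hypothetical Python analogue of the C++ BoundedExecutor in this diff."""

    _STOP = object()  # sentinel that tells one worker to leave its loop

    def __init__(
        self,
        max_concurrency: int,
        initialize_thread_callback: Optional[Callable[[], Releaser]] = None,
    ) -> None:
        if max_concurrency <= 0:
            raise ValueError("max_concurrency must be greater than 0")
        self._tasks = queue.Queue()
        self._threads = []
        for _ in range(max_concurrency):
            initialized = threading.Event()
            thread = threading.Thread(
                target=self._run,
                args=(initialize_thread_callback, initialized),
                daemon=True,
            )
            thread.start()
            # Block until this thread has run its initializer, mirroring
            # init_future.wait() in the C++ constructor.
            initialized.wait()
            self._threads.append(thread)

    def _run(self, initialize_thread_callback, initialized) -> None:
        releaser = initialize_thread_callback() if initialize_thread_callback else None
        initialized.set()
        while True:
            task = self._tasks.get()
            if task is self._STOP:
                break
            task()
        # The releaser runs on the same thread that ran the initializer.
        if releaser is not None:
            releaser()

    def post(self, fn: Callable[[], None]) -> None:
        self._tasks.put(fn)

    def join(self) -> None:
        # One sentinel per worker; tasks queued earlier are drained first.
        for _ in self._threads:
            self._tasks.put(self._STOP)
        for thread in self._threads:
            thread.join()


init_count = 0
release_count = 0
count_lock = threading.Lock()


def initialize_thread() -> Releaser:
    global init_count
    with count_lock:
        init_count += 1

    def release() -> None:
        global release_count
        with count_lock:
            release_count += 1

    return release


executor = BoundedExecutorSketch(3, initialize_thread)
counts_after_init = (init_count, release_count)  # all initializers ran, no releasers yet

done = threading.Event()
executor.post(done.set)
done.wait()

executor.join()
executor.join()  # a second join is harmless: the workers have already exited
```

The usage at the bottom follows the same steps as `InitializeThreadCallbackAndReleaserAreCalled` in `thread_pool_test.cc`: construct the pool, check the initializer count, post one task, join, and check the releaser count.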
17 changes: 11 additions & 6 deletions src/ray/core_worker/transport/thread_pool.h
@@ -14,11 +14,13 @@

#pragma once

#include <boost/asio/post.hpp>
#include <boost/asio/thread_pool.hpp>
#include <boost/asio/executor_work_guard.hpp>
#include <boost/asio/io_context.hpp>
#include <functional>
#include <memory>
#include <thread>
#include <utility>
#include <vector>

#include "ray/util/logging.h"

@@ -37,10 +39,12 @@ class BoundedExecutor {
return max_concurrency_in_default_group > 1 || has_other_concurrency_groups;
}

explicit BoundedExecutor(int max_concurrency);
explicit BoundedExecutor(
int max_concurrency,
std::function<std::function<void()>()> initialize_thread_callback = nullptr);

/// Posts work to the pool
void Post(std::function<void()> fn) { boost::asio::post(*pool_, std::move(fn)); }
void Post(std::function<void()> fn);

/// Stop the thread pool.
void Stop();
@@ -49,8 +53,9 @@ class BoundedExecutor {
void Join();

private:
/// The underlying thread pool for running tasks.
std::unique_ptr<boost::asio::thread_pool> pool_;
boost::asio::io_context io_context_;
boost::asio::executor_work_guard<boost::asio::io_context::executor_type> work_guard_;
std::vector<std::thread> threads_;
};

} // namespace core