Skip to content

Commit d52a9fa

Browse files
committed
[data] use num_cpus, memory, batch_size, concurrency
Signed-off-by: Xingyu Long <xingyulong97@gmail.com>
1 parent 6c9e8bf commit d52a9fa

File tree

3 files changed

+47
-42
lines changed

3 files changed

+47
-42
lines changed

python/ray/data/preprocessor.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,17 @@
44
import pickle
55
import warnings
66
from enum import Enum
7-
from typing import TYPE_CHECKING, Any, Dict, Union, List, Optional, Callable
7+
from typing import (
8+
TYPE_CHECKING,
9+
Any,
10+
Dict,
11+
Union,
12+
List,
13+
Optional,
14+
Callable,
15+
Literal,
16+
Tuple,
17+
)
818

919
from ray.air.util.data_batch_conversion import BatchFormat
1020
from ray.util.annotations import DeveloperAPI, PublicAPI
@@ -49,21 +59,22 @@ class Preprocessor(abc.ABC):
4959

5060
def __init__(
5161
self,
52-
ray_remote_args: Optional[Dict[str, Any]] = None,
53-
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
62+
num_cpus: Optional[float] = None,
63+
memory: Optional[float] = None,
64+
batch_size: Union[int, None, Literal["default"]] = None,
65+
concurrency: Optional[int] = None,
5466
):
5567
"""
5668
Args:
57-
ray_remote_args: Args to provide to :func:`ray.remote`.
58-
ray_remote_args_fn: A function that returns a dictionary of remote args
59-
passed to each map worker. The purpose of this argument is to generate
60-
dynamic arguments for each actor/task, and will be called each time
61-
prior to initializing the worker. Args returned from this dict will
62-
always override the args in ``ray_remote_args``. Note: this is an
63-
advanced, experimental feature.
69+
num_cpus: The number of CPUs to reserve for each parallel map worker.
70+
memory: The heap memory in bytes to reserve for each parallel map worker.
71+
batch_size: The desired number of rows in each batch.
72+
concurrency: The maximum number of Ray workers to use concurrently.
6473
"""
65-
self._ray_remote_args = ray_remote_args
66-
self._ray_remote_args_fn = ray_remote_args_fn
74+
self._num_cpus = num_cpus
75+
self._memory = memory
76+
self._batch_size = batch_size
77+
self._concurrency = concurrency
6778

6879
class FitStatus(str, Enum):
6980
"""The fit status of preprocessor."""
@@ -243,10 +254,10 @@ def _transform(self, ds: "Dataset") -> "Dataset":
243254
# Our user-facing batch format should only be pandas or NumPy, other
244255
# formats {arrow, simple} are internal.
245256
kwargs = self._get_transform_config()
246-
if self._ray_remote_args is not None:
247-
kwargs = dict(kwargs, **self._ray_remote_args)
248-
if self._ray_remote_args_fn is not None:
249-
kwargs["ray_remote_args_fn"] = self._ray_remote_args_fn
257+
kwargs["num_cpus"] = self._num_cpus
258+
kwargs["memory"] = self._memory
259+
kwargs["batch_size"] = self._batch_size
260+
kwargs["concurrency"] = self._concurrency
250261

251262
if transform_type == BatchFormat.PANDAS:
252263
return ds.map_batches(

python/ray/data/preprocessors/tokenizer.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable, List, Optional, Dict, Any
1+
from typing import Callable, List, Optional, Literal, Union
22

33
import pandas as pd
44

@@ -59,13 +59,10 @@ class Tokenizer(Preprocessor):
5959
columns will be the same as the input columns. If not None, the length of
6060
``output_columns`` must match the length of ``columns``, otherwise an error
6161
will be raised.
62-
ray_remote_args: Args to provide to :func:`ray.remote`.
63-
ray_remote_args_fn: A function that returns a dictionary of remote args
64-
passed to each map worker. The purpose of this argument is to generate
65-
dynamic arguments for each actor/task, and will be called each time
66-
prior to initializing the worker. Args returned from this dict will
67-
always override the args in ``ray_remote_args``. Note: this is an
68-
advanced, experimental feature.
62+
num_cpus: The number of CPUs to reserve for each parallel map worker.
63+
memory: The heap memory in bytes to reserve for each parallel map worker.
64+
batch_size: The desired number of rows in each batch.
65+
concurrency: The maximum number of Ray workers to use concurrently.
6966
"""
7067

7168
_is_fittable = False
@@ -75,12 +72,17 @@ def __init__(
7572
columns: List[str],
7673
tokenization_fn: Optional[Callable[[str], List[str]]] = None,
7774
output_columns: Optional[List[str]] = None,
78-
ray_remote_args: Optional[Dict[str, Any]] = None,
79-
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
75+
*,
76+
num_cpus: Optional[float] = None,
77+
memory: Optional[float] = None,
78+
batch_size: Union[int, None, Literal["default"]] = None,
79+
concurrency: Optional[int] = None,
8080
):
8181
super().__init__(
82-
ray_remote_args=ray_remote_args,
83-
ray_remote_args_fn=ray_remote_args_fn,
82+
num_cpus=num_cpus,
83+
memory=memory,
84+
batch_size=batch_size,
85+
concurrency=concurrency,
8486
)
8587
self.columns = columns
8688
# TODO(matt): Add a more robust default tokenizer.

python/ray/data/tests/preprocessors/test_preprocessors.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -165,17 +165,9 @@ def test_fit_twice(mocked_warn):
165165
mocked_warn.assert_called_once_with(msg)
166166

167167

168-
def test_ray_remote_args_and_fn():
168+
def test_initialization_parameters():
169169
batch_size = 2
170170

171-
ray_remote_args = {"num_cpus": 2}
172-
173-
def func(df):
174-
import os
175-
176-
df["value"][:] = int(os.environ["__MY_TEST__"])
177-
return df
178-
179171
class DummyPreprocessor(Preprocessor):
180172
_is_fittable = False
181173

@@ -185,23 +177,23 @@ def _get_transform_config(self):
185177
def _transform_numpy(self, data):
186178
assert (
187179
ray.get_runtime_context().get_assigned_resources()["CPU"]
188-
== ray_remote_args["num_cpus"]
180+
== self._num_cpus
189181
)
190182
assert len(data["value"]) == batch_size
191-
func(data)
192183
return data
193184

194185
def _determine_transform_to_use(self):
195186
return "numpy"
196187

197188
prep = DummyPreprocessor(
198-
ray_remote_args=ray_remote_args,
199-
ray_remote_args_fn=lambda: {"runtime_env": {"env_vars": {"__MY_TEST__": "69"}}},
189+
num_cpus=2,
190+
concurrency=2,
191+
batch_size=batch_size,
200192
)
201193
ds = ray.data.from_pandas(pd.DataFrame({"value": list(range(10))}))
202194
ds = prep.transform(ds)
203195

204-
assert sorted([x["value"] for x in ds.take(5)]) == [69, 69, 69, 69, 69]
196+
assert [x["value"] for x in ds.take(5)] == [0, 1, 2, 3, 4]
205197

206198

207199
def test_transform_config():

0 commit comments

Comments
 (0)