Skip to content

Commit 6a50837

Browse files
committed
[data] support ray_remote_args and ray_remote_args_fn for preprocessor
Signed-off-by: Xingyu Long <xingyulong97@gmail.com>
1 parent 86123ac commit 6a50837

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

python/ray/data/preprocessor.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pickle
55
import warnings
66
from enum import Enum
7-
from typing import TYPE_CHECKING, Any, Dict, Union, List, Optional
7+
from typing import TYPE_CHECKING, Any, Dict, Union, List, Optional, Callable
88

99
from ray.air.util.data_batch_conversion import BatchFormat
1010
from ray.util.annotations import DeveloperAPI, PublicAPI
@@ -47,6 +47,24 @@ class Preprocessor(abc.ABC):
4747
implemented method.
4848
"""
4949

50+
def __init__(
51+
self,
52+
ray_remote_args: Optional[Dict[str, Any]] = None,
53+
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
54+
):
55+
"""
56+
Args:
57+
ray_remote_args: Args to provide to :func:`ray.remote`.
58+
ray_remote_args_fn: A function that returns a dictionary of remote args
59+
passed to each map worker. The purpose of this argument is to generate
60+
dynamic arguments for each actor/task, and will be called each time
61+
prior to initializing the worker. Args returned from this dict will
62+
always override the args in ``ray_remote_args``. Note: this is an
63+
advanced, experimental feature.
64+
"""
65+
self._ray_remote_args = ray_remote_args
66+
self._ray_remote_args_fn = ray_remote_args_fn
67+
5068
class FitStatus(str, Enum):
5169
"""The fit status of preprocessor."""
5270

@@ -225,6 +243,11 @@ def _transform(self, ds: "Dataset") -> "Dataset":
225243
# Our user-facing batch format should only be pandas or NumPy, other
226244
# formats {arrow, simple} are internal.
227245
kwargs = self._get_transform_config()
246+
if self._ray_remote_args is not None:
247+
kwargs = dict(kwargs, **self._ray_remote_args)
248+
if self._ray_remote_args_fn is not None:
249+
kwargs["ray_remote_args_fn"] = self._ray_remote_args_fn
250+
228251
if transform_type == BatchFormat.PANDAS:
229252
return ds.map_batches(
230253
self._transform_pandas, batch_format=BatchFormat.PANDAS, **kwargs

python/ray/data/preprocessors/tokenizer.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable, List, Optional
1+
from typing import Callable, List, Optional, Dict, Any
22

33
import pandas as pd
44

@@ -59,6 +59,13 @@ class Tokenizer(Preprocessor):
5959
columns will be the same as the input columns. If not None, the length of
6060
``output_columns`` must match the length of ``columns``, othwerwise an error
6161
will be raised.
62+
ray_remote_args: Args to provide to :func:`ray.remote`.
63+
ray_remote_args_fn: A function that returns a dictionary of remote args
64+
passed to each map worker. The purpose of this argument is to generate
65+
dynamic arguments for each actor/task, and will be called each time
66+
prior to initializing the worker. Args returned from this dict will
67+
always override the args in ``ray_remote_args``. Note: this is an
68+
advanced, experimental feature.
6269
"""
6370

6471
_is_fittable = False
@@ -68,7 +75,13 @@ def __init__(
6875
columns: List[str],
6976
tokenization_fn: Optional[Callable[[str], List[str]]] = None,
7077
output_columns: Optional[List[str]] = None,
78+
ray_remote_args: Optional[Dict[str, Any]] = None,
79+
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
7180
):
81+
super().__init__(
82+
ray_remote_args=ray_remote_args,
83+
ray_remote_args_fn=ray_remote_args_fn,
84+
)
7285
self.columns = columns
7386
# TODO(matt): Add a more robust default tokenizer.
7487
self.tokenization_fn = tokenization_fn or simple_split_tokenizer

0 commit comments

Comments
 (0)