|
4 | 4 | import pickle
|
5 | 5 | import warnings
|
6 | 6 | from enum import Enum
|
7 |
| -from typing import TYPE_CHECKING, Any, Dict, Union, List, Optional |
| 7 | +from typing import TYPE_CHECKING, Any, Dict, Union, List, Optional, Callable |
8 | 8 |
|
9 | 9 | from ray.air.util.data_batch_conversion import BatchFormat
|
10 | 10 | from ray.util.annotations import DeveloperAPI, PublicAPI
|
@@ -47,6 +47,24 @@ class Preprocessor(abc.ABC):
|
47 | 47 | implemented method.
|
48 | 48 | """
|
49 | 49 |
|
| 50 | + def __init__( |
| 51 | + self, |
| 52 | + ray_remote_args: Optional[Dict[str, Any]] = None, |
| 53 | + ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None, |
| 54 | + ): |
| 55 | + """ |
| 56 | + Args: |
| 57 | + ray_remote_args: Args to provide to :func:`ray.remote`. |
| 58 | + ray_remote_args_fn: A function that returns a dictionary of remote args |
| 59 | + passed to each map worker. The purpose of this argument is to generate |
| 60 | + dynamic arguments for each actor/task; the function is called each time |
| 61 | + prior to initializing the worker. Args returned by this function will |
| 62 | + always override the args in ``ray_remote_args``. Note: this is an |
| 63 | + advanced, experimental feature. |
| 64 | + """ |
| 65 | + self._ray_remote_args = ray_remote_args |
| 66 | + self._ray_remote_args_fn = ray_remote_args_fn |
| 67 | + |
50 | 68 | class FitStatus(str, Enum):
|
51 | 69 | """The fit status of preprocessor."""
|
52 | 70 |
|
@@ -225,6 +243,11 @@ def _transform(self, ds: "Dataset") -> "Dataset":
|
225 | 243 | # Our user-facing batch format should only be pandas or NumPy, other
|
226 | 244 | # formats {arrow, simple} are internal.
|
227 | 245 | kwargs = self._get_transform_config()
|
| 246 | + if self._ray_remote_args is not None: |
| 247 | + kwargs = dict(kwargs, **self._ray_remote_args) |
| 248 | + if self._ray_remote_args_fn is not None: |
| 249 | + kwargs["ray_remote_args_fn"] = self._ray_remote_args_fn |
| 250 | + |
228 | 251 | if transform_type == BatchFormat.PANDAS:
|
229 | 252 | return ds.map_batches(
|
230 | 253 | self._transform_pandas, batch_format=BatchFormat.PANDAS, **kwargs
|
|
0 commit comments