Skip to content

[Data] support num_cpus, memory, concurrency, batch_size for preprocess #52574

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
64 changes: 57 additions & 7 deletions python/ray/data/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,15 @@ def fit(self, ds: "Dataset") -> "Preprocessor":
self._fitted = True
return fitted_ds

def fit_transform(self, ds: "Dataset") -> "Dataset":
def fit_transform(
self,
ds: "Dataset",
*,
transform_num_cpus: Optional[float] = None,
transform_memory: Optional[float] = None,
transform_batch_size: Optional[int] = None,
transform_concurrency: Optional[int] = None,
) -> "Dataset":
"""Fit this Preprocessor to the Dataset and then transform the Dataset.

Calling it more than once will overwrite all previously fitted state:
Expand All @@ -127,18 +135,40 @@ def fit_transform(self, ds: "Dataset") -> "Dataset":

Args:
ds: Input Dataset.
transform_num_cpus: [experimental] The number of CPUs to reserve for each parallel map worker.
transform_memory: [experimental] The heap memory in bytes to reserve for each parallel map worker.
transform_batch_size: [experimental] The maximum number of rows in each batch passed to the transform workers.
transform_concurrency: [experimental] The maximum number of Ray workers to use concurrently.

Returns:
ray.data.Dataset: The transformed Dataset.
"""
self.fit(ds)
return self.transform(ds)
return self.transform(
ds,
num_cpus=transform_num_cpus,
memory=transform_memory,
batch_size=transform_batch_size,
concurrency=transform_concurrency,
)

def transform(self, ds: "Dataset") -> "Dataset":
def transform(
self,
ds: "Dataset",
*,
batch_size: Optional[int] = None,
num_cpus: Optional[float] = None,
memory: Optional[float] = None,
concurrency: Optional[int] = None,
) -> "Dataset":
"""Transform the given dataset.

Args:
ds: Input Dataset.
batch_size: [experimental] Advanced configuration for adjusting input size for each worker.
num_cpus: [experimental] The number of CPUs to reserve for each parallel map worker.
memory: [experimental] The heap memory in bytes to reserve for each parallel map worker.
concurrency: [experimental] The maximum number of Ray workers to use concurrently.

Returns:
ray.data.Dataset: The transformed Dataset.
Expand All @@ -155,7 +185,13 @@ def transform(self, ds: "Dataset") -> "Dataset":
"`fit` must be called before `transform`, "
"or simply use fit_transform() to run both steps"
)
transformed_ds = self._transform(ds)
transformed_ds = self._transform(
ds,
batch_size=batch_size,
num_cpus=num_cpus,
memory=memory,
concurrency=concurrency,
)
return transformed_ds

def transform_batch(self, data: "DataBatchType") -> "DataBatchType":
Expand Down Expand Up @@ -217,14 +253,28 @@ def _determine_transform_to_use(self) -> BatchFormat:
"for Preprocessor transforms."
)

def _transform(self, ds: "Dataset") -> "Dataset":
# TODO(matt): Expose `batch_size` or similar configurability.
# The default may be too small for some datasets and too large for others.
def _transform(
self,
ds: "Dataset",
batch_size: Optional[int],
num_cpus: Optional[float] = None,
memory: Optional[float] = None,
concurrency: Optional[int] = None,
) -> "Dataset":
transform_type = self._determine_transform_to_use()

# Our user-facing batch format should only be pandas or NumPy, other
# formats {arrow, simple} are internal.
kwargs = self._get_transform_config()
if num_cpus is not None:
kwargs["num_cpus"] = num_cpus
if memory is not None:
kwargs["memory"] = memory
if batch_size is not None:
kwargs["batch_size"] = batch_size
if concurrency is not None:
kwargs["concurrency"] = concurrency

if transform_type == BatchFormat.PANDAS:
return ds.map_batches(
self._transform_pandas, batch_format=BatchFormat.PANDAS, **kwargs
Expand Down
19 changes: 16 additions & 3 deletions python/ray/data/preprocessors/chain.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

from ray.air.util.data_batch_conversion import BatchFormat
from ray.data import Dataset
Expand Down Expand Up @@ -79,9 +79,22 @@ def fit_transform(self, ds: Dataset) -> Dataset:
ds = preprocessor.fit_transform(ds)
return ds

def _transform(self, ds: Dataset) -> Dataset:
def _transform(
    self,
    ds: Dataset,
    batch_size: Optional[int],
    num_cpus: Optional[float] = None,
    memory: Optional[float] = None,
    concurrency: Optional[int] = None,
) -> Dataset:
    """Apply every preprocessor in the chain to ``ds``, in order.

    The batching/resource knobs (``batch_size``, ``num_cpus``, ``memory``,
    ``concurrency``) are forwarded unchanged to each underlying
    ``Preprocessor.transform`` call.
    """
    result = ds
    for step in self.preprocessors:
        # Each step consumes the output of the previous one.
        result = step.transform(
            result,
            batch_size=batch_size,
            num_cpus=num_cpus,
            memory=memory,
            concurrency=concurrency,
        )
    return result

def _transform_batch(self, df: "DataBatchType") -> "DataBatchType":
Expand Down
33 changes: 19 additions & 14 deletions python/ray/data/tests/preprocessors/test_preprocessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,33 +165,38 @@ def test_fit_twice(mocked_warn):
mocked_warn.assert_called_once_with(msg)


def test_transform_config():
"""Tests that the transform_config of
the Preprocessor is respected during transform."""

def test_transform_all_configs():
    """Transform should honor per-call resource settings (``num_cpus``,
    ``memory``, ``concurrency``) together with the ``batch_size`` supplied
    by the preprocessor's own ``_get_transform_config``."""
    batch_size = 2
    num_cpus = 2
    concurrency = 2
    memory = 1024

    class DummyPreprocessor(Preprocessor):
        _is_fittable = False

        def _get_transform_config(self):
            # The preprocessor-level config supplies the batch size.
            return {"batch_size": batch_size}

        def _determine_transform_to_use(self):
            return "numpy"

        def _transform_numpy(self, data):
            # Verify the per-call resources were actually assigned to the
            # map worker, and that batching respected batch_size.
            resources = ray.get_runtime_context().get_assigned_resources()
            assert resources["CPU"] == num_cpus
            assert resources["memory"] == memory
            assert len(data["value"]) == batch_size
            return data

        def _transform_pandas(self, data):
            raise RuntimeError(
                "Pandas transform should not be called with numpy batch format."
            )

    dataset = ray.data.from_pandas(pd.DataFrame({"value": list(range(10))}))
    transformed = DummyPreprocessor().transform(
        dataset,
        num_cpus=num_cpus,
        memory=memory,
        concurrency=concurrency,
    )
    assert [row["value"] for row in transformed.take(5)] == [0, 1, 2, 3, 4]


@pytest.mark.parametrize("dataset_format", ["simple", "pandas", "arrow"])
Expand Down