diff --git a/python/ray/data/preprocessor.py b/python/ray/data/preprocessor.py
index b9182dea3b548..2e7d69ff4265c 100644
--- a/python/ray/data/preprocessor.py
+++ b/python/ray/data/preprocessor.py
@@ -118,7 +118,15 @@ def fit(self, ds: "Dataset") -> "Preprocessor":
         self._fitted = True
         return fitted_ds
 
-    def fit_transform(self, ds: "Dataset") -> "Dataset":
+    def fit_transform(
+        self,
+        ds: "Dataset",
+        *,
+        transform_num_cpus: Optional[float] = None,
+        transform_memory: Optional[float] = None,
+        transform_batch_size: Optional[int] = None,
+        transform_concurrency: Optional[int] = None,
+    ) -> "Dataset":
         """Fit this Preprocessor to the Dataset and then transform the Dataset.
 
         Calling it more than once will overwrite all previously fitted state:
@@ -127,18 +135,40 @@ def fit_transform(self, ds: "Dataset") -> "Dataset":
 
         Args:
             ds: Input Dataset.
+            transform_num_cpus: [experimental] The number of CPUs to reserve for each parallel map worker.
+            transform_memory: [experimental] The heap memory in bytes to reserve for each parallel map worker.
+            transform_batch_size: [experimental] The maximum number of rows per batch passed to each parallel map worker.
+            transform_concurrency: [experimental] The maximum number of Ray workers to use concurrently.
 
         Returns:
             ray.data.Dataset: The transformed Dataset.
         """
         self.fit(ds)
-        return self.transform(ds)
+        return self.transform(
+            ds,
+            num_cpus=transform_num_cpus,
+            memory=transform_memory,
+            batch_size=transform_batch_size,
+            concurrency=transform_concurrency,
+        )
 
-    def transform(self, ds: "Dataset") -> "Dataset":
+    def transform(
+        self,
+        ds: "Dataset",
+        *,
+        batch_size: Optional[int] = None,
+        num_cpus: Optional[float] = None,
+        memory: Optional[float] = None,
+        concurrency: Optional[int] = None,
+    ) -> "Dataset":
         """Transform the given dataset.
 
         Args:
             ds: Input Dataset.
+            batch_size: [experimental] The maximum number of rows per batch passed to each parallel map worker.
+            num_cpus: [experimental] The number of CPUs to reserve for each parallel map worker.
+            memory: [experimental] The heap memory in bytes to reserve for each parallel map worker.
+            concurrency: [experimental] The maximum number of Ray workers to use concurrently.
 
         Returns:
             ray.data.Dataset: The transformed Dataset.
@@ -155,7 +185,13 @@ def transform(self, ds: "Dataset") -> "Dataset":
                 "`fit` must be called before `transform`, "
                 "or simply use fit_transform() to run both steps"
             )
-        transformed_ds = self._transform(ds)
+        transformed_ds = self._transform(
+            ds,
+            batch_size=batch_size,
+            num_cpus=num_cpus,
+            memory=memory,
+            concurrency=concurrency,
+        )
         return transformed_ds
 
     def transform_batch(self, data: "DataBatchType") -> "DataBatchType":
@@ -217,14 +253,28 @@ def _determine_transform_to_use(self) -> BatchFormat:
                 "for Preprocessor transforms."
             )
 
-    def _transform(self, ds: "Dataset") -> "Dataset":
-        # TODO(matt): Expose `batch_size` or similar configurability.
-        # The default may be too small for some datasets and too large for others.
+    def _transform(
+        self,
+        ds: "Dataset",
+        batch_size: Optional[int],
+        num_cpus: Optional[float] = None,
+        memory: Optional[float] = None,
+        concurrency: Optional[int] = None,
+    ) -> "Dataset":
         transform_type = self._determine_transform_to_use()
 
         # Our user-facing batch format should only be pandas or NumPy, other
        # formats {arrow, simple} are internal.
         kwargs = self._get_transform_config()
+        if num_cpus is not None:
+            kwargs["num_cpus"] = num_cpus
+        if memory is not None:
+            kwargs["memory"] = memory
+        if batch_size is not None:
+            kwargs["batch_size"] = batch_size
+        if concurrency is not None:
+            kwargs["concurrency"] = concurrency
+
         if transform_type == BatchFormat.PANDAS:
             return ds.map_batches(
                 self._transform_pandas, batch_format=BatchFormat.PANDAS, **kwargs
diff --git a/python/ray/data/preprocessors/chain.py b/python/ray/data/preprocessors/chain.py
index e608f8cf2f86a..018612ab9abb7 100644
--- a/python/ray/data/preprocessors/chain.py
+++ b/python/ray/data/preprocessors/chain.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
 
 from ray.air.util.data_batch_conversion import BatchFormat
 from ray.data import Dataset
@@ -79,9 +79,22 @@ def fit_transform(self, ds: Dataset) -> Dataset:
             ds = preprocessor.fit_transform(ds)
         return ds
 
-    def _transform(self, ds: Dataset) -> Dataset:
+    def _transform(
+        self,
+        ds: Dataset,
+        batch_size: Optional[int],
+        num_cpus: Optional[float] = None,
+        memory: Optional[float] = None,
+        concurrency: Optional[int] = None,
+    ) -> Dataset:
         for preprocessor in self.preprocessors:
-            ds = preprocessor.transform(ds)
+            ds = preprocessor.transform(
+                ds,
+                batch_size=batch_size,
+                num_cpus=num_cpus,
+                memory=memory,
+                concurrency=concurrency,
+            )
         return ds
 
     def _transform_batch(self, df: "DataBatchType") -> "DataBatchType":
diff --git a/python/ray/data/tests/preprocessors/test_preprocessors.py b/python/ray/data/tests/preprocessors/test_preprocessors.py
index 0adc3d8a8889d..48e2b1b25d757 100644
--- a/python/ray/data/tests/preprocessors/test_preprocessors.py
+++ b/python/ray/data/tests/preprocessors/test_preprocessors.py
@@ -165,16 +165,23 @@ def test_fit_twice(mocked_warn):
     mocked_warn.assert_called_once_with(msg)
 
 
-def test_transform_config():
-    """Tests that the transform_config of
-    the Preprocessor is respected during transform."""
-
+def test_transform_all_configs():
     batch_size = 2
+    num_cpus = 2
+    concurrency = 2
+    memory = 1024
 
     class DummyPreprocessor(Preprocessor):
         _is_fittable = False
 
+        def _get_transform_config(self):
+            return {"batch_size": batch_size}
+
         def _transform_numpy(self, data):
+            assert ray.get_runtime_context().get_assigned_resources()["CPU"] == num_cpus
+            assert (
+                ray.get_runtime_context().get_assigned_resources()["memory"] == memory
+            )
             assert len(data["value"]) == batch_size
             return data
 
@@ -183,15 +190,18 @@ def _transform_pandas(self, data):
                 "Pandas transform should not be called with numpy batch format."
            )
 
-        def _get_transform_config(self):
-            return {"batch_size": 2}
-
         def _determine_transform_to_use(self):
             return "numpy"
 
     prep = DummyPreprocessor()
-    ds = ray.data.from_pandas(pd.DataFrame({"value": list(range(4))}))
-    prep.transform(ds)
+    ds = ray.data.from_pandas(pd.DataFrame({"value": list(range(10))}))
+    ds = prep.transform(
+        ds,
+        num_cpus=num_cpus,
+        memory=memory,
+        concurrency=concurrency,
+    )
+    assert [x["value"] for x in ds.take(5)] == [0, 1, 2, 3, 4]
 
 
 @pytest.mark.parametrize("dataset_format", ["simple", "pandas", "arrow"])
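
For reference, a minimal usage sketch of the new keyword-only arguments introduced above, assuming the API lands as shown in the diff. `AddOne` is a hypothetical, non-fittable preprocessor used purely for illustration, and the argument values are arbitrary; the call mirrors the `prep.transform(...)` pattern in the updated test.

```python
import pandas as pd

import ray
from ray.data.preprocessor import Preprocessor


class AddOne(Preprocessor):
    """Hypothetical stateless preprocessor used only to illustrate the new kwargs."""

    _is_fittable = False  # no fitted state, so transform() can be called directly

    def _transform_pandas(self, df: pd.DataFrame) -> pd.DataFrame:
        df["value"] += 1
        return df


ds = ray.data.from_pandas(pd.DataFrame({"value": list(range(10))}))

# The new keyword-only arguments are forwarded to the underlying map_batches
# call: per-worker CPU and heap-memory reservations, the batch size, and a cap
# on the number of concurrent map workers.
out = AddOne().transform(
    ds,
    batch_size=4,
    num_cpus=1,
    memory=512 * 1024 * 1024,  # 512 MiB heap reserved per map worker
    concurrency=2,
)
assert [row["value"] for row in out.take(3)] == [1, 2, 3]
```

With `fit_transform`, the same settings are supplied through the `transform_`-prefixed names (`transform_batch_size`, `transform_num_cpus`, `transform_memory`, `transform_concurrency`), and `Chain._transform` forwards them unchanged to every preprocessor in the chain.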