
Commit bd5b17c

xingyu-long authored and GokuMohandas committed
[Data] support num_cpus, memory, concurrency, batch_size for preprocess (#52574)
## Why are these changes needed?

As the title says, we'd like to support the parameters above (`num_cpus`, `memory`, `concurrency`, `batch_size`) for preprocessing.

## Related issue number

Closes #52448

## Checks

- [x] I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- [x] I've run `scripts/format.sh` to lint the changes in this PR.
- [ ] I've included any doc changes needed for https://docs.ray.io/en/master/.
- [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.
- [x] I've made sure the tests are passing. Note that there might be a few flaky tests; see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
  - [x] Unit tests
  - [ ] Release tests
  - [ ] This PR is not tested :(

---

Signed-off-by: Xingyu Long <xingyulong97@gmail.com>
1 parent 34b89ba · commit bd5b17c

File tree

3 files changed: +92, -19 lines

python/ray/data/preprocessor.py (+57, -7)

@@ -118,7 +118,15 @@ def fit(self, ds: "Dataset") -> "Preprocessor":
         self._fitted = True
         return fitted_ds

-    def fit_transform(self, ds: "Dataset") -> "Dataset":
+    def fit_transform(
+        self,
+        ds: "Dataset",
+        *,
+        transform_num_cpus: Optional[float] = None,
+        transform_memory: Optional[float] = None,
+        transform_batch_size: Optional[int] = None,
+        transform_concurrency: Optional[int] = None,
+    ) -> "Dataset":
         """Fit this Preprocessor to the Dataset and then transform the Dataset.

         Calling it more than once will overwrite all previously fitted state:
@@ -127,18 +135,40 @@ def fit_transform(self, ds: "Dataset") -> "Dataset":

         Args:
             ds: Input Dataset.
+            transform_num_cpus: [experimental] The number of CPUs to reserve for each parallel map worker.
+            transform_memory: [experimental] The heap memory in bytes to reserve for each parallel map worker.
+            transform_batch_size: [experimental] The maximum number of rows to return.
+            transform_concurrency: [experimental] The maximum number of Ray workers to use concurrently.

         Returns:
             ray.data.Dataset: The transformed Dataset.
         """
         self.fit(ds)
-        return self.transform(ds)
+        return self.transform(
+            ds,
+            num_cpus=transform_num_cpus,
+            memory=transform_memory,
+            batch_size=transform_batch_size,
+            concurrency=transform_concurrency,
+        )

-    def transform(self, ds: "Dataset") -> "Dataset":
+    def transform(
+        self,
+        ds: "Dataset",
+        *,
+        batch_size: Optional[int] = None,
+        num_cpus: Optional[float] = None,
+        memory: Optional[float] = None,
+        concurrency: Optional[int] = None,
+    ) -> "Dataset":
         """Transform the given dataset.

         Args:
             ds: Input Dataset.
+            batch_size: [experimental] Advanced configuration for adjusting input size for each worker.
+            num_cpus: [experimental] The number of CPUs to reserve for each parallel map worker.
+            memory: [experimental] The heap memory in bytes to reserve for each parallel map worker.
+            concurrency: [experimental] The maximum number of Ray workers to use concurrently.

         Returns:
             ray.data.Dataset: The transformed Dataset.
@@ -155,7 +185,13 @@ def transform(self, ds: "Dataset") -> "Dataset":
                 "`fit` must be called before `transform`, "
                 "or simply use fit_transform() to run both steps"
             )
-        transformed_ds = self._transform(ds)
+        transformed_ds = self._transform(
+            ds,
+            batch_size=batch_size,
+            num_cpus=num_cpus,
+            memory=memory,
+            concurrency=concurrency,
+        )
         return transformed_ds

     def transform_batch(self, data: "DataBatchType") -> "DataBatchType":
@@ -217,14 +253,28 @@ def _determine_transform_to_use(self) -> BatchFormat:
                 "for Preprocessor transforms."
             )

-    def _transform(self, ds: "Dataset") -> "Dataset":
-        # TODO(matt): Expose `batch_size` or similar configurability.
-        # The default may be too small for some datasets and too large for others.
+    def _transform(
+        self,
+        ds: "Dataset",
+        batch_size: Optional[int],
+        num_cpus: Optional[float] = None,
+        memory: Optional[float] = None,
+        concurrency: Optional[int] = None,
+    ) -> "Dataset":
         transform_type = self._determine_transform_to_use()

         # Our user-facing batch format should only be pandas or NumPy, other
         # formats {arrow, simple} are internal.
         kwargs = self._get_transform_config()
+        if num_cpus is not None:
+            kwargs["num_cpus"] = num_cpus
+        if memory is not None:
+            kwargs["memory"] = memory
+        if batch_size is not None:
+            kwargs["batch_size"] = batch_size
+        if concurrency is not None:
+            kwargs["concurrency"] = concurrency
+
         if transform_type == BatchFormat.PANDAS:
             return ds.map_batches(
                 self._transform_pandas, batch_format=BatchFormat.PANDAS, **kwargs
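
A hedged usage sketch of the new keywords (the `StandardScaler` choice, column name, and resource values below are illustrative, not taken from this PR): `fit_transform` accepts the `transform_`-prefixed variants and forwards them to `transform`, which merges them into the kwargs passed to `ds.map_batches`.

```python
# Sketch only: StandardScaler and the toy dataset are illustrative; the
# keyword arguments are the ones introduced in this PR.
import pandas as pd
import ray
from ray.data.preprocessors import StandardScaler

ds = ray.data.from_pandas(pd.DataFrame({"value": list(range(100))}))
scaler = StandardScaler(columns=["value"])

# Each parallel map worker reserves 2 CPUs and 1 GiB of heap memory,
# batches are capped at 32 rows, and at most 4 workers run at once.
out = scaler.fit_transform(
    ds,
    transform_num_cpus=2,
    transform_memory=1024**3,
    transform_batch_size=32,
    transform_concurrency=4,
)

# After fitting, transform() accepts the same four options without the prefix.
out = scaler.transform(ds, num_cpus=2, memory=1024**3, batch_size=32, concurrency=4)
print(out.take(3))
```

Since `_transform` only copies options into `kwargs` when they are not `None`, omitting any of them leaves the existing `map_batches` defaults and any `_get_transform_config()` values untouched.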

python/ray/data/preprocessors/chain.py (+16, -3)

@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional

 from ray.air.util.data_batch_conversion import BatchFormat
 from ray.data import Dataset
@@ -79,9 +79,22 @@ def fit_transform(self, ds: Dataset) -> Dataset:
             ds = preprocessor.fit_transform(ds)
         return ds

-    def _transform(self, ds: Dataset) -> Dataset:
+    def _transform(
+        self,
+        ds: Dataset,
+        batch_size: Optional[int],
+        num_cpus: Optional[float] = None,
+        memory: Optional[float] = None,
+        concurrency: Optional[int] = None,
+    ) -> Dataset:
         for preprocessor in self.preprocessors:
-            ds = preprocessor.transform(ds)
+            ds = preprocessor.transform(
+                ds,
+                batch_size=batch_size,
+                num_cpus=num_cpus,
+                memory=memory,
+                concurrency=concurrency,
+            )
         return ds

     def _transform_batch(self, df: "DataBatchType") -> "DataBatchType":
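
Because `Chain._transform` now forwards the same options to each wrapped preprocessor's `transform()`, one call configures every stage. A hedged sketch (the concrete preprocessors, column, and values are illustrative):

```python
# Sketch only: Chain forwards batch_size/num_cpus/memory/concurrency to
# each stage's transform(); the preprocessors chosen here are illustrative.
import pandas as pd
import ray
from ray.data.preprocessors import Chain, MinMaxScaler, StandardScaler

ds = ray.data.from_pandas(pd.DataFrame({"value": [1.0, 2.0, 3.0, 4.0]}))

chain = Chain(
    MinMaxScaler(columns=["value"]),
    StandardScaler(columns=["value"]),
)
chain.fit(ds)

# Both stages run their map workers with 1 CPU, 256 MiB of heap memory,
# and at most 2 concurrent workers.
out = chain.transform(ds, num_cpus=1, memory=256 * 1024 * 1024, concurrency=2)
print(out.take_all())
```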

python/ray/data/tests/preprocessors/test_preprocessors.py (+19, -9)

@@ -165,16 +165,23 @@ def test_fit_twice(mocked_warn):
     mocked_warn.assert_called_once_with(msg)


-def test_transform_config():
-    """Tests that the transform_config of
-    the Preprocessor is respected during transform."""
-
+def test_transform_all_configs():
     batch_size = 2
+    num_cpus = 2
+    concurrency = 2
+    memory = 1024

     class DummyPreprocessor(Preprocessor):
         _is_fittable = False

+        def _get_transform_config(self):
+            return {"batch_size": batch_size}
+
         def _transform_numpy(self, data):
+            assert ray.get_runtime_context().get_assigned_resources()["CPU"] == num_cpus
+            assert (
+                ray.get_runtime_context().get_assigned_resources()["memory"] == memory
+            )
             assert len(data["value"]) == batch_size
             return data

@@ -183,15 +190,18 @@ def _transform_pandas(self, data):
                 "Pandas transform should not be called with numpy batch format."
             )

-        def _get_transform_config(self):
-            return {"batch_size": 2}
-
         def _determine_transform_to_use(self):
             return "numpy"

     prep = DummyPreprocessor()
-    ds = ray.data.from_pandas(pd.DataFrame({"value": list(range(4))}))
-    prep.transform(ds)
+    ds = ray.data.from_pandas(pd.DataFrame({"value": list(range(10))}))
+    ds = prep.transform(
+        ds,
+        num_cpus=num_cpus,
+        memory=memory,
+        concurrency=concurrency,
+    )
+    assert [x["value"] for x in ds.take(5)] == [0, 1, 2, 3, 4]


 @pytest.mark.parametrize("dataset_format", ["simple", "pandas", "arrow"])
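
One detail the updated test leans on, shown as a hedged sketch below: an explicit `batch_size` passed to `transform()` overrides whatever `_get_transform_config()` supplies, because `_transform()` writes the explicit value into `kwargs` after reading the config. The class name and sizes here are hypothetical, mirroring the `DummyPreprocessor` pattern in the test.

```python
# Hedged sketch (not part of the PR): an explicit batch_size passed to
# transform() wins over the value returned by _get_transform_config().
import pandas as pd
import ray
from ray.data.preprocessor import Preprocessor


class FixedBatchPreprocessor(Preprocessor):
    _is_fittable = False

    def _get_transform_config(self):
        # Built-in default batch size for this preprocessor.
        return {"batch_size": 2}

    def _determine_transform_to_use(self):
        return "numpy"

    def _transform_numpy(self, batch):
        # transform(..., batch_size=5) overrides the config above.
        assert len(batch["value"]) == 5
        return batch


ds = ray.data.from_pandas(pd.DataFrame({"value": list(range(10))}))
# take_all() materializes the lazy Dataset so the assertion actually runs.
FixedBatchPreprocessor().transform(ds, batch_size=5).take_all()
```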
