Skip to content

Commit 9900995

Browse files
committed
[data] move parameters to transform, fit_transform func
Signed-off-by: Xingyu Long <xingyulong97@gmail.com>
1 parent d52a9fa commit 9900995

File tree

3 files changed

+63
-84
lines changed

3 files changed

+63
-84
lines changed

python/ray/data/preprocessor.py

+50-30
Original file line numberDiff line numberDiff line change
@@ -57,25 +57,6 @@ class Preprocessor(abc.ABC):
5757
implemented method.
5858
"""
5959

60-
def __init__(
61-
self,
62-
num_cpus: Optional[float] = None,
63-
memory: Optional[float] = None,
64-
batch_size: Union[int, None, Literal["default"]] = None,
65-
concurrency: Optional[int] = None,
66-
):
67-
"""
68-
Args:
69-
num_cpus: The number of CPUs to reserve for each parallel map worker.
70-
memory: The heap memory in bytes to reserve for each parallel map worker.
71-
batch_size: The maximum number of rows to return.
72-
concurrency: The maximum number of Ray workers to use concurrently.
73-
"""
74-
self._num_cpus = num_cpus
75-
self._memory = memory
76-
self._batch_size = batch_size
77-
self._concurrency = concurrency
78-
7960
class FitStatus(str, Enum):
8061
"""The fit status of preprocessor."""
8162

@@ -147,7 +128,15 @@ def fit(self, ds: "Dataset") -> "Preprocessor":
147128
self._fitted = True
148129
return fitted_ds
149130

150-
def fit_transform(self, ds: "Dataset") -> "Dataset":
131+
def fit_transform(
132+
self,
133+
ds: "Dataset",
134+
*,
135+
transform_num_cpus: Optional[float] = None,
136+
transform_memory: Optional[float] = None,
137+
transform_batch_size: Union[int, None, Literal["default"]] = None,
138+
transform_concurrency: Optional[int] = None,
139+
) -> "Dataset":
151140
"""Fit this Preprocessor to the Dataset and then transform the Dataset.
152141
153142
Calling it more than once will overwrite all previously fitted state:
@@ -156,18 +145,40 @@ def fit_transform(self, ds: "Dataset") -> "Dataset":
156145
157146
Args:
158147
ds: Input Dataset.
148+
transform_num_cpus: The number of CPUs to reserve for each parallel map worker.
149+
transform_memory: The heap memory in bytes to reserve for each parallel map worker.
150+
transform_batch_size: The maximum number of rows to return.
151+
transform_concurrency: The maximum number of Ray workers to use concurrently.
159152
160153
Returns:
161154
ray.data.Dataset: The transformed Dataset.
162155
"""
163156
self.fit(ds)
164-
return self.transform(ds)
157+
return self.transform(
158+
ds,
159+
num_cpus=transform_num_cpus,
160+
memory=transform_memory,
161+
batch_size=transform_batch_size,
162+
concurrency=transform_concurrency,
163+
)
165164

166-
def transform(self, ds: "Dataset") -> "Dataset":
165+
def transform(
166+
self,
167+
ds: "Dataset",
168+
*,
169+
num_cpus: Optional[float] = None,
170+
memory: Optional[float] = None,
171+
batch_size: Union[int, None, Literal["default"]] = None,
172+
concurrency: Optional[int] = None,
173+
) -> "Dataset":
167174
"""Transform the given dataset.
168175
169176
Args:
170177
ds: Input Dataset.
178+
num_cpus: The number of CPUs to reserve for each parallel map worker.
179+
memory: The heap memory in bytes to reserve for each parallel map worker.
180+
batch_size: The maximum number of rows to return.
181+
concurrency: The maximum number of Ray workers to use concurrently.
171182
172183
Returns:
173184
ray.data.Dataset: The transformed Dataset.
@@ -184,7 +195,7 @@ def transform(self, ds: "Dataset") -> "Dataset":
184195
"`fit` must be called before `transform`, "
185196
"or simply use fit_transform() to run both steps"
186197
)
187-
transformed_ds = self._transform(ds)
198+
transformed_ds = self._transform(ds, num_cpus, memory, batch_size, concurrency)
188199
return transformed_ds
189200

190201
def transform_batch(self, data: "DataBatchType") -> "DataBatchType":
@@ -246,18 +257,27 @@ def _determine_transform_to_use(self) -> BatchFormat:
246257
"for Preprocessor transforms."
247258
)
248259

249-
def _transform(self, ds: "Dataset") -> "Dataset":
250-
# TODO(matt): Expose `batch_size` or similar configurability.
251-
# The default may be too small for some datasets and too large for others.
260+
def _transform(
261+
self,
262+
ds: "Dataset",
263+
num_cpus: Optional[float] = None,
264+
memory: Optional[float] = None,
265+
batch_size: Union[int, None, Literal["default"]] = None,
266+
concurrency: Optional[int] = None,
267+
) -> "Dataset":
252268
transform_type = self._determine_transform_to_use()
253269

254270
# Our user-facing batch format should only be pandas or NumPy, other
255271
# formats {arrow, simple} are internal.
256272
kwargs = self._get_transform_config()
257-
kwargs["num_cpus"] = self._num_cpus
258-
kwargs["memory"] = self._memory
259-
kwargs["batch_size"] = self._batch_size
260-
kwargs["concurrency"] = self._concurrency
273+
if num_cpus is not None:
274+
kwargs["num_cpus"] = num_cpus
275+
if memory is not None:
276+
kwargs["memory"] = memory
277+
if batch_size is not None:
278+
kwargs["batch_size"] = batch_size
279+
if concurrency is not None:
280+
kwargs["concurrency"] = concurrency
261281

262282
if transform_type == BatchFormat.PANDAS:
263283
return ds.map_batches(

python/ray/data/preprocessors/tokenizer.py

-15
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,6 @@ class Tokenizer(Preprocessor):
5959
columns will be the same as the input columns. If not None, the length of
6060
``output_columns`` must match the length of ``columns``, otherwise an error
6161
will be raised.
62-
num_cpus: The number of CPUs to reserve for each parallel map worker.
63-
memory: The heap memory in bytes to reserve for each parallel map worker.
64-
batch_size: The maximum number of rows to return.
65-
concurrency: The maximum number of Ray workers to use concurrently.
6662
"""
6763

6864
_is_fittable = False
@@ -72,18 +68,7 @@ def __init__(
7268
columns: List[str],
7369
tokenization_fn: Optional[Callable[[str], List[str]]] = None,
7470
output_columns: Optional[List[str]] = None,
75-
*,
76-
num_cpus: Optional[float] = None,
77-
memory: Optional[float] = None,
78-
batch_size: Union[int, None, Literal["default"]] = None,
79-
concurrency: Optional[int] = None,
8071
):
81-
super().__init__(
82-
num_cpus=num_cpus,
83-
memory=memory,
84-
batch_size=batch_size,
85-
concurrency=concurrency,
86-
)
8772
self.columns = columns
8873
# TODO(matt): Add a more robust default tokenizer.
8974
self.tokenization_fn = tokenization_fn or simple_split_tokenizer

python/ray/data/tests/preprocessors/test_preprocessors.py

+13-39
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,11 @@ def test_fit_twice(mocked_warn):
165165
mocked_warn.assert_called_once_with(msg)
166166

167167

168-
def test_initialization_parameters():
168+
def test_transform_all_configs():
169169
batch_size = 2
170+
num_cpus = 2
171+
concurrency = 2
172+
memory = 1024
170173

171174
class DummyPreprocessor(Preprocessor):
172175
_is_fittable = False
@@ -175,56 +178,27 @@ def _get_transform_config(self):
175178
return {"batch_size": batch_size}
176179

177180
def _transform_numpy(self, data):
181+
assert ray.get_runtime_context().get_assigned_resources()["CPU"] == num_cpus
178182
assert (
179-
ray.get_runtime_context().get_assigned_resources()["CPU"]
180-
== self._num_cpus
183+
ray.get_runtime_context().get_assigned_resources()["memory"] == memory
181184
)
182185
assert len(data["value"]) == batch_size
183186
return data
184187

185188
def _determine_transform_to_use(self):
186189
return "numpy"
187190

188-
prep = DummyPreprocessor(
189-
num_cpus=2,
190-
concurrency=2,
191-
batch_size=batch_size,
192-
)
191+
prep = DummyPreprocessor()
193192
ds = ray.data.from_pandas(pd.DataFrame({"value": list(range(10))}))
194-
ds = prep.transform(ds)
195-
193+
ds = prep.transform(
194+
ds,
195+
num_cpus=num_cpus,
196+
memory=memory,
197+
concurrency=concurrency,
198+
)
196199
assert [x["value"] for x in ds.take(5)] == [0, 1, 2, 3, 4]
197200

198201

199-
def test_transform_config():
200-
"""Tests that the transform_config of
201-
the Preprocessor is respected during transform."""
202-
203-
batch_size = 2
204-
205-
class DummyPreprocessor(Preprocessor):
206-
_is_fittable = False
207-
208-
def _transform_numpy(self, data):
209-
assert len(data["value"]) == batch_size
210-
return data
211-
212-
def _transform_pandas(self, data):
213-
raise RuntimeError(
214-
"Pandas transform should not be called with numpy batch format."
215-
)
216-
217-
def _get_transform_config(self):
218-
return {"batch_size": 2}
219-
220-
def _determine_transform_to_use(self):
221-
return "numpy"
222-
223-
prep = DummyPreprocessor()
224-
ds = ray.data.from_pandas(pd.DataFrame({"value": list(range(4))}))
225-
prep.transform(ds)
226-
227-
228202
@pytest.mark.parametrize("dataset_format", ["simple", "pandas", "arrow"])
229203
def test_transform_all_formats(create_dummy_preprocessors, dataset_format):
230204
(

0 commit comments

Comments
 (0)