
Commit 7caf5af

[data] move parameters to transform, fit_transform func
Signed-off-by: Xingyu Long <xingyulong97@gmail.com>
1 parent: d52a9fa

File tree (3 files changed: 63 additions, 86 deletions)

python/ray/data/preprocessor.py
python/ray/data/preprocessors/tokenizer.py
python/ray/data/tests/preprocessors/test_preprocessors.py


python/ray/data/preprocessor.py

Lines changed: 50 additions & 32 deletions
@@ -11,9 +11,7 @@
     Union,
     List,
     Optional,
-    Callable,
     Literal,
-    Tuple,
 )
 
 from ray.air.util.data_batch_conversion import BatchFormat
@@ -57,25 +55,6 @@ class Preprocessor(abc.ABC):
     implemented method.
     """
 
-    def __init__(
-        self,
-        num_cpus: Optional[float] = None,
-        memory: Optional[float] = None,
-        batch_size: Union[int, None, Literal["default"]] = None,
-        concurrency: Optional[int] = None,
-    ):
-        """
-        Args:
-            num_cpus: The number of CPUs to reserve for each parallel map worker.
-            memory: The heap memory in bytes to reserve for each parallel map worker.
-            batch_size: The maximum number of rows to return.
-            concurrency: The maximum number of Ray workers to use concurrently.
-        """
-        self._num_cpus = num_cpus
-        self._memory = memory
-        self._batch_size = batch_size
-        self._concurrency = concurrency
-
     class FitStatus(str, Enum):
         """The fit status of preprocessor."""
 
@@ -147,7 +126,15 @@ def fit(self, ds: "Dataset") -> "Preprocessor":
         self._fitted = True
         return fitted_ds
 
-    def fit_transform(self, ds: "Dataset") -> "Dataset":
+    def fit_transform(
+        self,
+        ds: "Dataset",
+        *,
+        transform_num_cpus: Optional[float] = None,
+        transform_memory: Optional[float] = None,
+        transform_batch_size: Union[int, None, Literal["default"]] = None,
+        transform_concurrency: Optional[int] = None,
+    ) -> "Dataset":
         """Fit this Preprocessor to the Dataset and then transform the Dataset.
 
         Calling it more than once will overwrite all previously fitted state:
@@ -156,18 +143,40 @@ def fit_transform(self, ds: "Dataset") -> "Dataset":
 
         Args:
             ds: Input Dataset.
+            transform_num_cpus: The number of CPUs to reserve for each parallel map worker.
+            transform_memory: The heap memory in bytes to reserve for each parallel map worker.
+            transform_batch_size: The maximum number of rows to return.
+            transform_concurrency: The maximum number of Ray workers to use concurrently.
 
         Returns:
             ray.data.Dataset: The transformed Dataset.
         """
         self.fit(ds)
-        return self.transform(ds)
+        return self.transform(
+            ds,
+            num_cpus=transform_num_cpus,
+            memory=transform_memory,
+            batch_size=transform_batch_size,
+            concurrency=transform_concurrency,
+        )
 
-    def transform(self, ds: "Dataset") -> "Dataset":
+    def transform(
+        self,
+        ds: "Dataset",
+        *,
+        num_cpus: Optional[float] = None,
+        memory: Optional[float] = None,
+        batch_size: Union[int, None, Literal["default"]] = None,
+        concurrency: Optional[int] = None,
+    ) -> "Dataset":
         """Transform the given dataset.
 
         Args:
             ds: Input Dataset.
+            num_cpus: The number of CPUs to reserve for each parallel map worker.
+            memory: The heap memory in bytes to reserve for each parallel map worker.
+            batch_size: The maximum number of rows to return.
+            concurrency: The maximum number of Ray workers to use concurrently.
 
         Returns:
             ray.data.Dataset: The transformed Dataset.
@@ -184,7 +193,7 @@ def transform(self, ds: "Dataset") -> "Dataset":
                 "`fit` must be called before `transform`, "
                 "or simply use fit_transform() to run both steps"
             )
-        transformed_ds = self._transform(ds)
+        transformed_ds = self._transform(ds, num_cpus, memory, batch_size, concurrency)
         return transformed_ds
 
     def transform_batch(self, data: "DataBatchType") -> "DataBatchType":
@@ -246,18 +255,27 @@ def _determine_transform_to_use(self) -> BatchFormat:
                 "for Preprocessor transforms."
             )
 
-    def _transform(self, ds: "Dataset") -> "Dataset":
-        # TODO(matt): Expose `batch_size` or similar configurability.
-        # The default may be too small for some datasets and too large for others.
+    def _transform(
+        self,
+        ds: "Dataset",
+        num_cpus: Optional[float] = None,
+        memory: Optional[float] = None,
+        batch_size: Union[int, None, Literal["default"]] = None,
+        concurrency: Optional[int] = None,
+    ) -> "Dataset":
         transform_type = self._determine_transform_to_use()
 
         # Our user-facing batch format should only be pandas or NumPy, other
         # formats {arrow, simple} are internal.
         kwargs = self._get_transform_config()
-        kwargs["num_cpus"] = self._num_cpus
-        kwargs["memory"] = self._memory
-        kwargs["batch_size"] = self._batch_size
-        kwargs["concurrency"] = self._concurrency
+        if num_cpus is not None:
+            kwargs["num_cpus"] = num_cpus
+        if memory is not None:
+            kwargs["memory"] = memory
+        if batch_size is not None:
+            kwargs["batch_size"] = batch_size
+        if concurrency is not None:
+            kwargs["concurrency"] = concurrency
 
         if transform_type == BatchFormat.PANDAS:
             return ds.map_batches(
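
The net effect of this diff is that per-call resource and batching options replace the constructor state. A minimal usage sketch of the new call pattern; the MinMaxScaler preprocessor, dataset contents, and option values below are illustrative assumptions, not part of this commit:

# Sketch of the new per-call API (illustrative preprocessor, data, and values).
import ray
from ray.data.preprocessors import MinMaxScaler

ds = ray.data.from_items([{"value": float(i)} for i in range(100)])

# Resource options are no longer passed to the constructor.
scaler = MinMaxScaler(columns=["value"])

# fit_transform forwards the transform_* options to transform().
scaled = scaler.fit_transform(
    ds,
    transform_num_cpus=1,                # CPUs reserved per map worker
    transform_memory=512 * 1024 * 1024,  # heap memory per map worker, in bytes
    transform_batch_size=1024,           # max rows per batch
    transform_concurrency=2,             # max concurrent map workers
)

# Or fit and transform separately, using the un-prefixed names on transform().
scaler.fit(ds)
scaled = scaler.transform(ds, num_cpus=1, concurrency=2)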

python/ray/data/preprocessors/tokenizer.py

Lines changed: 0 additions & 15 deletions
@@ -59,10 +59,6 @@ class Tokenizer(Preprocessor):
             columns will be the same as the input columns. If not None, the length of
             ``output_columns`` must match the length of ``columns``, othwerwise an error
             will be raised.
-        num_cpus: The number of CPUs to reserve for each parallel map worker.
-        memory: The heap memory in bytes to reserve for each parallel map worker.
-        batch_size: The maximum number of rows to return.
-        concurrency: The maximum number of Ray workers to use concurrently.
     """
 
     _is_fittable = False
@@ -72,18 +68,7 @@ def __init__(
         columns: List[str],
         tokenization_fn: Optional[Callable[[str], List[str]]] = None,
        output_columns: Optional[List[str]] = None,
-        *,
-        num_cpus: Optional[float] = None,
-        memory: Optional[float] = None,
-        batch_size: Union[int, None, Literal["default"]] = None,
-        concurrency: Optional[int] = None,
     ):
-        super().__init__(
-            num_cpus=num_cpus,
-            memory=memory,
-            batch_size=batch_size,
-            concurrency=concurrency,
-        )
         self.columns = columns
         # TODO(matt): Add a more robust default tokenizer.
         self.tokenization_fn = tokenization_fn or simple_split_tokenizer
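
With the constructor kwargs gone, Tokenizer is built from its column arguments alone and any resource options go to transform(). A short sketch under those assumptions; the dataset contents and option values are illustrative:

# Sketch of Tokenizer usage after this commit (toy data, illustrative values).
import ray
from ray.data.preprocessors import Tokenizer

ds = ray.data.from_items([{"text": "hello ray data"}, {"text": "tokenize me"}])

tokenizer = Tokenizer(columns=["text"])          # resource kwargs removed from __init__
tokenized = tokenizer.transform(ds, num_cpus=1)  # pass them per transform() call instead
print(tokenized.take(2))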

python/ray/data/tests/preprocessors/test_preprocessors.py

Lines changed: 13 additions & 39 deletions
@@ -165,8 +165,11 @@ def test_fit_twice(mocked_warn):
     mocked_warn.assert_called_once_with(msg)
 
 
-def test_initialization_parameters():
+def test_transform_all_configs():
     batch_size = 2
+    num_cpus = 2
+    concurrency = 2
+    memory = 1024
 
     class DummyPreprocessor(Preprocessor):
         _is_fittable = False
@@ -175,56 +178,27 @@ def _get_transform_config(self):
             return {"batch_size": batch_size}
 
         def _transform_numpy(self, data):
+            assert ray.get_runtime_context().get_assigned_resources()["CPU"] == num_cpus
             assert (
-                ray.get_runtime_context().get_assigned_resources()["CPU"]
-                == self._num_cpus
+                ray.get_runtime_context().get_assigned_resources()["memory"] == memory
             )
             assert len(data["value"]) == batch_size
             return data
 
         def _determine_transform_to_use(self):
             return "numpy"
 
-    prep = DummyPreprocessor(
-        num_cpus=2,
-        concurrency=2,
-        batch_size=batch_size,
-    )
+    prep = DummyPreprocessor()
     ds = ray.data.from_pandas(pd.DataFrame({"value": list(range(10))}))
-    ds = prep.transform(ds)
-
+    ds = prep.transform(
+        ds,
+        num_cpus=num_cpus,
+        memory=memory,
+        concurrency=concurrency,
+    )
     assert [x["value"] for x in ds.take(5)] == [0, 1, 2, 3, 4]
 
 
-def test_transform_config():
-    """Tests that the transform_config of
-    the Preprocessor is respected during transform."""
-
-    batch_size = 2
-
-    class DummyPreprocessor(Preprocessor):
-        _is_fittable = False
-
-        def _transform_numpy(self, data):
-            assert len(data["value"]) == batch_size
-            return data
-
-        def _transform_pandas(self, data):
-            raise RuntimeError(
-                "Pandas transform should not be called with numpy batch format."
-            )
-
-        def _get_transform_config(self):
-            return {"batch_size": 2}
-
-        def _determine_transform_to_use(self):
-            return "numpy"
-
-    prep = DummyPreprocessor()
-    ds = ray.data.from_pandas(pd.DataFrame({"value": list(range(4))}))
-    prep.transform(ds)
-
-
 @pytest.mark.parametrize("dataset_format", ["simple", "pandas", "arrow"])
 def test_transform_all_formats(create_dummy_preprocessors, dataset_format):
     (
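
Because the new _transform only writes a key into the _get_transform_config() kwargs when the corresponding argument is not None, an explicit per-call value overrides a preprocessor's built-in default, while omitted options leave the config untouched. A hypothetical sketch of that precedence; the Upper preprocessor below is invented for illustration:

# Hypothetical subclass illustrating override precedence in _transform():
# an explicit batch_size passed to transform() replaces the value from
# _get_transform_config(); omitted options keep the configured default.
import ray
from ray.data.preprocessor import Preprocessor

class Upper(Preprocessor):
    _is_fittable = False

    def _get_transform_config(self):
        return {"batch_size": 4096}  # this preprocessor's default batching

    def _transform_pandas(self, df):
        df["text"] = df["text"].str.upper()
        return df

ds = ray.data.from_items([{"text": "a"}, {"text": "b"}, {"text": "c"}])
out = Upper().transform(ds, batch_size=2)  # 2 overrides the 4096 default
print(out.take_all())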
