@@ -57,25 +57,6 @@ class Preprocessor(abc.ABC):
     implemented method.
     """
 
-    def __init__(
-        self,
-        num_cpus: Optional[float] = None,
-        memory: Optional[float] = None,
-        batch_size: Union[int, None, Literal["default"]] = None,
-        concurrency: Optional[int] = None,
-    ):
-        """
-        Args:
-            num_cpus: The number of CPUs to reserve for each parallel map worker.
-            memory: The heap memory in bytes to reserve for each parallel map worker.
-            batch_size: The maximum number of rows to return.
-            concurrency: The maximum number of Ray workers to use concurrently.
-        """
-        self._num_cpus = num_cpus
-        self._memory = memory
-        self._batch_size = batch_size
-        self._concurrency = concurrency
-
     class FitStatus(str, Enum):
         """The fit status of preprocessor."""
 
@@ -147,7 +128,15 @@ def fit(self, ds: "Dataset") -> "Preprocessor":
         self._fitted = True
         return fitted_ds
 
-    def fit_transform(self, ds: "Dataset") -> "Dataset":
+    def fit_transform(
+        self,
+        ds: "Dataset",
+        *,
+        transform_num_cpus: Optional[float] = None,
+        transform_memory: Optional[float] = None,
+        transform_batch_size: Union[int, None, Literal["default"]] = None,
+        transform_concurrency: Optional[int] = None,
+    ) -> "Dataset":
         """Fit this Preprocessor to the Dataset and then transform the Dataset.
 
         Calling it more than once will overwrite all previously fitted state:
@@ -156,18 +145,40 @@ def fit_transform(self, ds: "Dataset") -> "Dataset":
 
         Args:
             ds: Input Dataset.
+            transform_num_cpus: The number of CPUs to reserve for each parallel map worker.
+            transform_memory: The heap memory in bytes to reserve for each parallel map worker.
+            transform_batch_size: The maximum number of rows to return.
+            transform_concurrency: The maximum number of Ray workers to use concurrently.
 
         Returns:
             ray.data.Dataset: The transformed Dataset.
         """
         self.fit(ds)
-        return self.transform(ds)
+        return self.transform(
+            ds,
+            num_cpus=transform_num_cpus,
+            memory=transform_memory,
+            batch_size=transform_batch_size,
+            concurrency=transform_concurrency,
+        )
 
-    def transform(self, ds: "Dataset") -> "Dataset":
+    def transform(
+        self,
+        ds: "Dataset",
+        *,
+        num_cpus: Optional[float] = None,
+        memory: Optional[float] = None,
+        batch_size: Union[int, None, Literal["default"]] = None,
+        concurrency: Optional[int] = None,
+    ) -> "Dataset":
         """Transform the given dataset.
 
         Args:
             ds: Input Dataset.
+            num_cpus: The number of CPUs to reserve for each parallel map worker.
+            memory: The heap memory in bytes to reserve for each parallel map worker.
+            batch_size: The maximum number of rows to return.
+            concurrency: The maximum number of Ray workers to use concurrently.
 
         Returns:
             ray.data.Dataset: The transformed Dataset.
@@ -184,7 +195,7 @@ def transform(self, ds: "Dataset") -> "Dataset":
                 "`fit` must be called before `transform`, "
                 "or simply use fit_transform() to run both steps"
            )
-        transformed_ds = self._transform(ds)
+        transformed_ds = self._transform(ds, num_cpus, memory, batch_size, concurrency)
         return transformed_ds
 
     def transform_batch(self, data: "DataBatchType") -> "DataBatchType":
@@ -246,18 +257,27 @@ def _determine_transform_to_use(self) -> BatchFormat:
                 "for Preprocessor transforms."
             )
 
-    def _transform(self, ds: "Dataset") -> "Dataset":
-        # TODO(matt): Expose `batch_size` or similar configurability.
-        # The default may be too small for some datasets and too large for others.
+    def _transform(
+        self,
+        ds: "Dataset",
+        num_cpus: Optional[float] = None,
+        memory: Optional[float] = None,
+        batch_size: Union[int, None, Literal["default"]] = None,
+        concurrency: Optional[int] = None,
+    ) -> "Dataset":
         transform_type = self._determine_transform_to_use()
 
         # Our user-facing batch format should only be pandas or NumPy, other
         # formats {arrow, simple} are internal.
         kwargs = self._get_transform_config()
-        kwargs["num_cpus"] = self._num_cpus
-        kwargs["memory"] = self._memory
-        kwargs["batch_size"] = self._batch_size
-        kwargs["concurrency"] = self._concurrency
+        if num_cpus is not None:
+            kwargs["num_cpus"] = num_cpus
+        if memory is not None:
+            kwargs["memory"] = memory
+        if batch_size is not None:
+            kwargs["batch_size"] = batch_size
+        if concurrency is not None:
+            kwargs["concurrency"] = concurrency
 
         if transform_type == BatchFormat.PANDAS:
             return ds.map_batches(
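
Taken together, the hunks above drop the resource options from `Preprocessor.__init__` and expose them as keyword-only arguments on `transform` (and, with a `transform_` prefix, on `fit_transform`), forwarding only the ones that are explicitly set into the `map_batches` kwargs. A minimal usage sketch of the new call pattern; the `StandardScaler` preprocessor, the toy dataset, and the specific values are illustrative and not part of this diff:

import ray
from ray.data.preprocessors import StandardScaler

# Illustrative in-memory dataset; any Ray Dataset with the scaled column would do.
ds = ray.data.from_items([{"value": float(i)} for i in range(100)])

# Construct a preprocessor; per this change, resource options are supplied at
# transform time rather than carried on the preprocessor object.
scaler = StandardScaler(columns=["value"])

# fit_transform forwards the transform_* keywords to transform().
transformed = scaler.fit_transform(
    ds,
    transform_num_cpus=0.5,        # CPUs reserved per parallel map worker
    transform_memory=2 * 1024**3,  # heap memory per worker, in bytes (2 GiB here)
    transform_batch_size=1024,     # maximum rows per batch
    transform_concurrency=4,       # maximum concurrent Ray workers
)

# transform() itself takes the un-prefixed keyword-only arguments; anything
# left as None is simply omitted from the map_batches() kwargs.
transformed = scaler.transform(ds, num_cpus=0.5, concurrency=4)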