     Union,
     List,
     Optional,
-    Callable,
     Literal,
-    Tuple,
 )

 from ray.air.util.data_batch_conversion import BatchFormat
@@ -57,25 +55,6 @@ class Preprocessor(abc.ABC):
       implemented method.
     """

-    def __init__(
-        self,
-        num_cpus: Optional[float] = None,
-        memory: Optional[float] = None,
-        batch_size: Union[int, None, Literal["default"]] = None,
-        concurrency: Optional[int] = None,
-    ):
-        """
-        Args:
-            num_cpus: The number of CPUs to reserve for each parallel map worker.
-            memory: The heap memory in bytes to reserve for each parallel map worker.
-            batch_size: The maximum number of rows to return.
-            concurrency: The maximum number of Ray workers to use concurrently.
-        """
-        self._num_cpus = num_cpus
-        self._memory = memory
-        self._batch_size = batch_size
-        self._concurrency = concurrency
-
     class FitStatus(str, Enum):
         """The fit status of preprocessor."""

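With these constructor arguments removed, resource settings are no longer stored on the preprocessor instance; they move to the call sites changed below. A minimal before/after sketch, assuming a hypothetical concrete subclass `MyPreprocessor` and an existing `ray.data.Dataset` named `ds`:

```python
# Before this change (removed above): resources were bound at construction time.
#     prep = MyPreprocessor(num_cpus=1, concurrency=8)
#     prep.fit_transform(ds)
#
# After this change: the constructor carries no resource state, and the same
# options are passed per call (see the transform()/fit_transform() hunks below).
#     prep = MyPreprocessor()
#     prep.fit_transform(ds, transform_num_cpus=1, transform_concurrency=8)
```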
@@ -147,7 +126,15 @@ def fit(self, ds: "Dataset") -> "Preprocessor":
         self._fitted = True
         return fitted_ds

-    def fit_transform(self, ds: "Dataset") -> "Dataset":
+    def fit_transform(
+        self,
+        ds: "Dataset",
+        *,
+        transform_num_cpus: Optional[float] = None,
+        transform_memory: Optional[float] = None,
+        transform_batch_size: Union[int, None, Literal["default"]] = None,
+        transform_concurrency: Optional[int] = None,
+    ) -> "Dataset":
         """Fit this Preprocessor to the Dataset and then transform the Dataset.

         Calling it more than once will overwrite all previously fitted state:
@@ -156,18 +143,40 @@ def fit_transform(self, ds: "Dataset") -> "Dataset":

         Args:
             ds: Input Dataset.
+            transform_num_cpus: The number of CPUs to reserve for each parallel map worker.
+            transform_memory: The heap memory in bytes to reserve for each parallel map worker.
+            transform_batch_size: The maximum number of rows to return.
+            transform_concurrency: The maximum number of Ray workers to use concurrently.

         Returns:
             ray.data.Dataset: The transformed Dataset.
         """
         self.fit(ds)
-        return self.transform(ds)
+        return self.transform(
+            ds,
+            num_cpus=transform_num_cpus,
+            memory=transform_memory,
+            batch_size=transform_batch_size,
+            concurrency=transform_concurrency,
+        )

-    def transform(self, ds: "Dataset") -> "Dataset":
+    def transform(
+        self,
+        ds: "Dataset",
+        *,
+        num_cpus: Optional[float] = None,
+        memory: Optional[float] = None,
+        batch_size: Union[int, None, Literal["default"]] = None,
+        concurrency: Optional[int] = None,
+    ) -> "Dataset":
         """Transform the given dataset.

         Args:
             ds: Input Dataset.
+            num_cpus: The number of CPUs to reserve for each parallel map worker.
+            memory: The heap memory in bytes to reserve for each parallel map worker.
+            batch_size: The maximum number of rows to return.
+            concurrency: The maximum number of Ray workers to use concurrently.

         Returns:
             ray.data.Dataset: The transformed Dataset.
@@ -184,7 +193,7 @@ def transform(self, ds: "Dataset") -> "Dataset":
                 "`fit` must be called before `transform`, "
                 "or simply use fit_transform() to run both steps"
             )
-        transformed_ds = self._transform(ds)
+        transformed_ds = self._transform(ds, num_cpus, memory, batch_size, concurrency)
         return transformed_ds

     def transform_batch(self, data: "DataBatchType") -> "DataBatchType":
@@ -246,18 +255,27 @@ def _determine_transform_to_use(self) -> BatchFormat:
                 "for Preprocessor transforms."
             )

-    def _transform(self, ds: "Dataset") -> "Dataset":
-        # TODO(matt): Expose `batch_size` or similar configurability.
-        # The default may be too small for some datasets and too large for others.
+    def _transform(
+        self,
+        ds: "Dataset",
+        num_cpus: Optional[float] = None,
+        memory: Optional[float] = None,
+        batch_size: Union[int, None, Literal["default"]] = None,
+        concurrency: Optional[int] = None,
+    ) -> "Dataset":
         transform_type = self._determine_transform_to_use()

         # Our user-facing batch format should only be pandas or NumPy, other
         # formats {arrow, simple} are internal.
         kwargs = self._get_transform_config()
-        kwargs["num_cpus"] = self._num_cpus
-        kwargs["memory"] = self._memory
-        kwargs["batch_size"] = self._batch_size
-        kwargs["concurrency"] = self._concurrency
+        if num_cpus is not None:
+            kwargs["num_cpus"] = num_cpus
+        if memory is not None:
+            kwargs["memory"] = memory
+        if batch_size is not None:
+            kwargs["batch_size"] = batch_size
+        if concurrency is not None:
+            kwargs["concurrency"] = concurrency

         if transform_type == BatchFormat.PANDAS:
             return ds.map_batches(
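One consequence of the conditional assignments above: an option left as `None` is simply not passed to `map_batches`, so Ray Data's own defaults apply rather than an explicit `None`. A standalone sketch of that pattern, using a hypothetical helper name `_resource_overrides`:

```python
from typing import Any, Dict, Optional

def _resource_overrides(
    num_cpus: Optional[float] = None,
    memory: Optional[float] = None,
    concurrency: Optional[int] = None,
) -> Dict[str, Any]:
    """Collect only explicitly set options; unset ones fall back to map_batches defaults."""
    overrides: Dict[str, Any] = {}
    if num_cpus is not None:
        overrides["num_cpus"] = num_cpus
    if memory is not None:
        overrides["memory"] = memory
    if concurrency is not None:
        overrides["concurrency"] = concurrency
    return overrides

assert _resource_overrides(num_cpus=1) == {"num_cpus": 1}
assert _resource_overrides() == {}
```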