@@ -228,7 +228,7 @@ def check_and_delete_corrupt_h5_file(file_path, logger):
228
228
logger .info (f"File does not exist '{ basename } '" )
229
229
230
230
231
def standardize_features(x_train, x_test, scaler=RobustScaler):
    """Scale the training and test features with a shared scaler.

    A scaler is created by calling ``scaler`` with no arguments. It is
    fitted on the training features only, and the resulting transformation
    is then applied to the test features.

    Parameters
    ----------
    x_train : array-like
        Training set features; used both to fit the scaler and as input
        to the first transformation.
    x_test : array-like
        Test set features; transformed with the scaler fitted on
        ``x_train``.
    scaler : callable, optional
        Zero-argument callable (typically a scaler class) returning an
        object that provides ``fit_transform`` and ``transform``.
        Defaults to ``RobustScaler``.

    Returns
    -------
    x_train : array-like
        Standardized training set features.
    x_test : array-like
        Standardized test set features.
    """
    feature_scaler = scaler()
    # Fit on the training data first so the test data is scaled with the
    # parameters learned from the training set.
    scaled_train = feature_scaler.fit_transform(x_train)
    scaled_test = feature_scaler.transform(x_test)
    return scaled_train, scaled_test
253
253
254
-
255
class Standardize:
    """Standardize features with a configurable scaler class.

    NumPy inputs are scaled with a single scaler. Dictionary inputs get one
    independently fitted scaler per key. Inputs of any other type are
    ignored by ``fit`` and returned unchanged by ``transform``.

    Parameters
    ----------
    scalar : callable, optional
        Zero-argument callable (typically a scaler class) returning an
        object that provides ``fit`` and ``transform``. Defaults to
        ``RobustScaler``. The parameter keeps its historical spelling
        ("scalar") so existing keyword callers keep working.

    Attributes
    ----------
    scalar : object
        The callable passed to the constructor until a NumPy input has been
        fitted; afterwards, the scaler instance fitted on that input.
    n_features : list, int or None
        ``None`` before fitting. After fitting a dictionary, the list of
        its keys. After fitting a NumPy input, the size of its last axis.
    scalars : dict
        Maps each key of the most recently fitted dictionary to the value
        returned by that key's scaler ``fit`` call.
    """

    def __init__(self, scalar=RobustScaler):
        self.scalar = scalar
        # ``fit`` rebinds ``self.scalar`` to a fitted instance for NumPy
        # inputs, so the factory is kept separately to allow refitting.
        self._scalar_factory = scalar
        self.n_features = None
        self.scalars = dict()

    def fit(self, X):
        """Fit the scaler(s) to the data.

        Parameters
        ----------
        X : numpy.ndarray, numpy.generic or dict
            The data to fit on. For a dictionary, every value is fitted
            with its own freshly created scaler.

        Returns
        -------
        self : Standardize
            This object, so calls can be chained.
        """
        if isinstance(X, dict):
            self.n_features = list(X.keys())
            # Rebuild the mapping from scratch so scalers fitted for keys of
            # an earlier dictionary do not linger after a refit.
            self.scalars = {
                key: self._scalar_factory().fit(values)
                for key, values in X.items()
            }
        elif isinstance(X, (np.ndarray, np.generic)):
            # Always instantiate from the stored factory; calling
            # ``self.scalar`` here would fail on a second fit because it
            # would already be an instance rather than a class.
            self.scalar = self._scalar_factory()
            self.scalar.fit(X)
            self.n_features = X.shape[-1]
        return self

    def transform(self, X):
        """Apply the fitted scaling transformation to the data.

        Parameters
        ----------
        X : numpy.ndarray, numpy.generic or dict
            The data to transform. A dictionary must contain every key
            that was present when ``fit`` was called on a dictionary.

        Returns
        -------
        X : numpy.ndarray, numpy.generic or dict
            The transformed data. A dictionary input is updated in place
            and the same dictionary object is returned. Inputs of any
            other unsupported type are returned unchanged.
        """
        if isinstance(X, dict):
            # NOTE: values are overwritten in the caller's dictionary;
            # callers that need the raw data must pass a copy.
            for name in self.n_features:
                X[name] = self.scalars[name].transform(X[name])
        elif isinstance(X, (np.ndarray, np.generic)):
            X = self.scalar.transform(X)
        return X

    def fit_transform(self, X):
        """Fit the scaler(s) to the data, then transform it.

        Parameters
        ----------
        X : numpy.ndarray, numpy.generic or dict
            The data to fit on and transform.

        Returns
        -------
        X : numpy.ndarray, numpy.generic or dict
            The transformed data, as returned by ``transform``.
        """
        self.fit(X)
        return self.transform(X)
0 commit comments