Skip to content

Commit bd3c4e8

Browse files
author
prithagupta
committed
Updated the Utility Functions
1 parent eabd75e commit bd3c4e8

File tree

2 files changed

+64
-14
lines changed

2 files changed

+64
-14
lines changed

autoqild/utilities/utils.py

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import traceback
77
import h5py
88
import numpy as np
9-
from sklearn.preprocessing import RobustScaler
9+
from sklearn.preprocessing import StandardScaler, RobustScaler, MinMaxScaler
1010
import warnings
1111

1212
warnings.filterwarnings("ignore")
@@ -228,26 +228,76 @@ def check_and_delete_corrupt_h5_file(file_path, logger):
228228
logger.info(f"File does not exist '{basename}'")
229229

230230

231-
def standardize_features(x_train, x_test, scaler=RobustScaler):
232-
"""Standardize the features in the training and test sets using
233-
RobustScaler as a default.
231+
def standardize_features(x_train, x_test, scaler=RobustScaler, scaler_params={}):
232+
"""
233+
Standardize the features in the training and test sets using the specified scaler.
234+
235+
The function offers flexibility to choose between `StandardScaler`, `RobustScaler`, and `MinMaxScaler`.
236+
It allows customization of the chosen scaler’s parameters using a dictionary and raises a ValueError
237+
if an unsupported scaler is passed.
234238
235239
Parameters
236240
----------
237-
x_train : array-like
241+
x_train : array-like of shape (n_samples, n_features)
238242
Training set features.
239-
x_test : array-like
243+
x_test : array-like of shape (n_samples, n_features)
240244
Test set features.
245+
scaler : {StandardScaler, RobustScaler, MinMaxScaler}, optional, default=RobustScaler
246+
The scaling class to be used for standardization. Choose from:
247+
- StandardScaler: Standardize features by removing the mean and scaling to unit variance.
248+
- RobustScaler: Scale features using statistics that are robust to outliers.
249+
- MinMaxScaler: Scale features to a given range (usually between 0 and 1).
250+
scaler_params : dict, optional, default={}
251+
Parameters to be passed to the selected scaler. Example: {'with_mean': False} for `StandardScaler`.
241252
242253
Returns
243254
-------
244-
x_train : array-like
255+
x_train : array-like of shape (n_samples, n_features)
245256
Standardized training set features.
246-
x_test : array-like
257+
x_test : array-like of shape (n_samples, n_features)
247258
Standardized test set features.
259+
260+
Raises
261+
------
262+
ValueError
263+
If the specified scaler is not one of `StandardScaler`, `RobustScaler`, or `MinMaxScaler`.
264+
265+
Example
266+
-------
267+
>>> from sklearn.preprocessing import StandardScaler, RobustScaler, MinMaxScaler
268+
>>> import numpy as np
269+
>>> x_train = np.array([[1, 2], [2, 3], [3, 4]])
270+
>>> x_test = np.array([[4, 5], [5, 6]])
271+
272+
# Example with StandardScaler and a custom parameter
273+
>>> scaler_params = {'with_mean': False}
274+
>>> x_train_scaled, x_test_scaled = standardize_features(
275+
... x_train, x_test, scaler=StandardScaler, scaler_params=scaler_params
276+
... )
277+
278+
# Example with RobustScaler (default)
279+
>>> x_train_scaled, x_test_scaled = standardize_features(x_train, x_test, scaler=RobustScaler)
280+
281+
# Example with MinMaxScaler
282+
>>> x_train_scaled, x_test_scaled = standardize_features(x_train, x_test, scaler=MinMaxScaler)
283+
284+
# Example with an invalid scaler (this will raise a ValueError)
285+
>>> try:
286+
... x_train_scaled, x_test_scaled = standardize_features(x_train, x_test, scaler="InvalidScaler")
287+
... except ValueError as e:
288+
... print(e)
289+
'Invalid scaler specified. Choose from StandardScaler, RobustScaler, or MinMaxScaler.'
248290
"""
249-
standardize = scaler()
250-
x_train = standardize.fit_transform(x_train)
251-
x_test = standardize.transform(x_test)
252-
return x_train, x_test
291+
if scaler not in [StandardScaler, RobustScaler, MinMaxScaler]:
292+
raise ValueError(
293+
"Invalid scaler specified. Choose from StandardScaler, RobustScaler, or MinMaxScaler."
294+
)
295+
296+
# Initialize the chosen scaler with the specified parameters
297+
scaler_instance = scaler(**scaler_params)
253298

299+
# Fit the scaler on the training data and transform both training and test data
300+
x_train = scaler_instance.fit_transform(x_train)
301+
x_test = scaler_instance.transform(x_test)
302+
303+
return x_train, x_test

docs/source/notebooks/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ def setup_logging(log_path=None, level=logging.INFO):
8282
logging.getLogger("pytorch").setLevel(logging.ERROR)
8383
logging.getLogger("torch").setLevel(logging.ERROR)
8484
logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR)
85-
os.environ['TF_TRT_LOGGER'] = 'ERROR'
86-
tf.get_logger().setLevel('ERROR')
85+
os.environ["TF_TRT_LOGGER"] = "ERROR"
86+
tf.get_logger().setLevel("ERROR")
8787

8888

8989
def setup_random_seed(random_state=1234):

0 commit comments

Comments
 (0)