Skip to content

Commit eabd75e

Browse files
author
prithagupta
committed
Updated the utils
1 parent 0afec08 commit eabd75e

File tree

2 files changed

+3
-84
lines changed

2 files changed

+3
-84
lines changed

autoqild/utilities/utils.py

Lines changed: 2 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def check_and_delete_corrupt_h5_file(file_path, logger):
228228
logger.info(f"File does not exist '{basename}'")
229229

230230

231-
def standardize_features(x_train, x_test):
231+
def standardize_features(x_train, x_test, scaler=RobustScaler):
232232
"""Standardize the features in the training and test sets using
233233
RobustScaler as a default.
234234
@@ -246,89 +246,8 @@ def standardize_features(x_train, x_test):
246246
x_test : array-like
247247
Standardized test set features.
248248
"""
249-
standardize = Standardize()
249+
standardize = scaler()
250250
x_train = standardize.fit_transform(x_train)
251251
x_test = standardize.transform(x_test)
252252
return x_train, x_test
253253

254-
255-
class Standardize:
256-
"""A class for standardizing features using a specified scaler.
257-
258-
Parameters
259-
----------
260-
scalar : object, optional
261-
The scaling class to use (default is "RobustScaler").
262-
263-
Attributes
264-
----------
265-
n_features : list or None
266-
The list of feature names if `X` is a dictionary.
267-
scalars : dict
268-
A dictionary of scalers for each feature if `X` is a dictionary.
269-
"""
270-
271-
def __init__(self, scalar=RobustScaler):
272-
self.scalar = scalar
273-
self.n_features = None
274-
self.scalars = dict()
275-
276-
def fit(self, X):
277-
"""Fit the scaler to the data.
278-
279-
Parameters
280-
----------
281-
X : array-like or dict
282-
The data to fit the scaler on.
283-
284-
Returns
285-
-------
286-
self : object
287-
Fitted scaler.
288-
"""
289-
if isinstance(X, dict):
290-
self.n_features = list(X.keys())
291-
for k, x in X.items():
292-
scalar = self.scalar()
293-
self.scalars[k] = scalar.fit(x)
294-
if isinstance(X, (np.ndarray, np.generic)):
295-
self.scalar = self.scalar()
296-
self.scalar.fit(X)
297-
self.n_features = X.shape[-1]
298-
299-
def transform(self, X):
300-
"""Apply the scaling transformation to the data.
301-
302-
Parameters
303-
----------
304-
X : array-like or dict
305-
The data to transform.
306-
307-
Returns
308-
-------
309-
X : array-like or dict
310-
The transformed data.
311-
"""
312-
if isinstance(X, dict):
313-
for n in self.n_features:
314-
X[n] = self.scalars[n].transform(X[n])
315-
if isinstance(X, (np.ndarray, np.generic)):
316-
X = self.scalar.transform(X)
317-
return X
318-
319-
def fit_transform(self, X):
320-
"""Fit the scaler and transform the data.
321-
322-
Parameters
323-
----------
324-
X : array-like or dict
325-
The data to fit and transform.
326-
327-
Returns
328-
-------
329-
X : array-like or dict
330-
The transformed data.
331-
"""
332-
self.fit(X)
333-
X = self.transform(X)
334-
return X

docs/source/references.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,6 @@ List of references for the implemented learning algorithms, AutoML tools and bas
2626
-------------------------
2727
🚀 Baseline MI Estimators
2828
-------------------------
29-
- `Gaussain Mixture Models <https://ieeexplore.ieee.org/document/6889561>`_: Polo et al. (2022)
29+
- `Gaussain Mixture Model (GMM) <https://ieeexplore.ieee.org/document/6889561>`_: Polo et al. (2022)
3030
- `Mutual Information Neural Estimation (MINE) <https://proceedings.mlr.press/v80/belghazi18a/belghazi18a.pdf>`_: Belghazi et al. (2018)
3131
- `PC-softmax <https://arxiv.org/abs/1911.10688>`_: Qin et al. (2020)

0 commit comments

Comments
 (0)