Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion ctgan/data_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,27 @@ class DataTransformer(object):
Discrete columns are encoded using a scikit-learn OneHotEncoder.
"""

def __init__(self, max_clusters=10, weight_threshold=0.005):
def __init__(self, max_clusters=10, weight_threshold=0.005, max_gm_samples=None):
"""Create a data transformer.

Args:
max_clusters (int):
Maximum number of Gaussian distributions in Bayesian GMM.
weight_threshold (float):
Weight threshold for a Gaussian distribution to be kept.
max_gm_samples (int):
Maximum number of samples to use during GMM fit.
"""
self._max_clusters = max_clusters
self._weight_threshold = weight_threshold
self._max_gm_samples = np.inf if max_gm_samples is None else max_gm_samples

def _fit_continuous(self, column_name, raw_column_data):
"""Train Bayesian GMM for continuous column."""
if self._max_gm_samples <= raw_column_data.shape[0]:
raw_column_data = np.random.choice(raw_column_data,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that, for this kind of line break, the following indentation is better:

    raw_column_data = np.random.choice(
        raw_column_data,
        size=self._max_gm_samples,
        replace=False
    )

size=self._max_gm_samples,
replace=False)
gm = BayesianGaussianMixture(
self._max_clusters,
weight_concentration_prior_type='dirichlet_process',
Expand Down
7 changes: 5 additions & 2 deletions ctgan/synthesizers/ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,8 @@ def _validate_discrete_columns(self, train_data, discrete_columns):
if invalid_columns:
raise ValueError('Invalid columns found: {}'.format(invalid_columns))

def fit(self, train_data, discrete_columns=tuple(), epochs=None):
def fit(self, train_data, discrete_columns=tuple(), epochs=None,
data_transformer_params={}):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The data_transformer_params should be moved to the __init__ and be assigned as self.data_transformer_params. (Use deepcopy if needed.)

"""Fit the CTGAN Synthesizer models to the training data.

Args:
Expand All @@ -278,6 +279,8 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
Vector. If ``train_data`` is a Numpy array, this list should
contain the integer indices of the columns. Otherwise, if it is
a ``pandas.DataFrame``, this list should contain the column names.
data_transformer_params (dict):
Dictionary of parameters for ``DataTransformer`` initialization.
"""
self._validate_discrete_columns(train_data, discrete_columns)

Expand All @@ -290,7 +293,7 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
DeprecationWarning
)

self._transformer = DataTransformer()
self._transformer = DataTransformer(**data_transformer_params)
self._transformer.fit(train_data, discrete_columns)

train_data = self._transformer.transform(train_data)
Expand Down
11 changes: 11 additions & 0 deletions tests/integration/test_ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,14 @@ def test_wrong_sampling_conditions():

with pytest.raises(ValueError):
ctgan.sample(1, 'discrete', "d")


def test_ctgan_data_transformer_params():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you should also add a performance test, something simple just to make sure that our results are not worse than before because of this change.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure about this one: are you thinking of a performance test of the Gaussian mixture model or of CTGAN? In terms of speed or accuracy?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Accuracy for CTGAN. Basically, just a test to make sure the changes don't break the code. So something like changing your continuous column to be drawn from a normal distribution instead of a uniform random one, then sampling from the model (after you fit) and making sure the samples loosely follow a normal distribution.

data = pd.DataFrame({
'continuous': np.random.random(1000)
})

ctgan = CTGANSynthesizer(epochs=1)
ctgan.fit(data, [], data_transformer_params={'max_gm_samples': 100})

assert ctgan._transformer._max_gm_samples == 100