-
Notifications
You must be signed in to change notification settings - Fork 324
DataTransformer init parameters #146
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
4b0d505
4560e78
96a6321
8669599
2a3222f
1b40159
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -267,7 +267,8 @@ def _validate_discrete_columns(self, train_data, discrete_columns): | |
| if invalid_columns: | ||
| raise ValueError('Invalid columns found: {}'.format(invalid_columns)) | ||
|
|
||
| def fit(self, train_data, discrete_columns=tuple(), epochs=None): | ||
| def fit(self, train_data, discrete_columns=tuple(), epochs=None, | ||
| data_transformer_params={}): | ||
|
||
| """Fit the CTGAN Synthesizer models to the training data. | ||
|
|
||
| Args: | ||
|
|
@@ -278,6 +279,8 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None): | |
| Vector. If ``train_data`` is a Numpy array, this list should | ||
| contain the integer indices of the columns. Otherwise, if it is | ||
| a ``pandas.DataFrame``, this list should contain the column names. | ||
| data_transformer_params (dict): | ||
| Dictionary of parameters for ``DataTransformer`` initialization. | ||
| """ | ||
| self._validate_discrete_columns(train_data, discrete_columns) | ||
|
|
||
|
|
@@ -290,7 +293,7 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None): | |
| DeprecationWarning | ||
| ) | ||
|
|
||
| self._transformer = DataTransformer() | ||
| self._transformer = DataTransformer(**data_transformer_params) | ||
| self._transformer.fit(train_data, discrete_columns) | ||
|
|
||
| train_data = self._transformer.transform(train_data) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -184,3 +184,14 @@ def test_wrong_sampling_conditions(): | |
|
|
||
| with pytest.raises(ValueError): | ||
| ctgan.sample(1, 'discrete', "d") | ||
|
|
||
|
|
||
| def test_ctgan_data_transformer_params(): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you should also add a performance test, something simple just to make sure that our results are not worse than before because of this change. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure about this one, do you think about a performance test of the gaussian mixture model or CTGAN ? In terms of speed or accuracy ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Accuracy for CTGAN. Basically, just a test to make sure the changes don't break the code. So something like changing your |
||
| data = pd.DataFrame({ | ||
| 'continuous': np.random.random(1000) | ||
| }) | ||
|
|
||
| ctgan = CTGANSynthesizer(epochs=1) | ||
| ctgan.fit(data, [], data_transformer_params={'max_gm_samples': 100}) | ||
|
|
||
| assert ctgan._transformer._max_gm_samples == 100 | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that when it comes to this kind of line breaking this indentation is better: