From 736ebdd8b6e56fc7a24d01db0cea8e656ed8057c Mon Sep 17 00:00:00 2001 From: Andres Algaba Date: Wed, 20 Jul 2022 14:26:40 +0200 Subject: [PATCH 1/7] fix bugs --- ctgan/data_sampler.py | 2 +- ctgan/synthesizers/ctgan.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/ctgan/data_sampler.py b/ctgan/data_sampler.py index 5cbf339d..8a0a956b 100644 --- a/ctgan/data_sampler.py +++ b/ctgan/data_sampler.py @@ -150,7 +150,7 @@ def dim_cond_vec(self): def generate_cond_from_condition_column_info(self, condition_info, batch): """Generate the condition vector.""" vec = np.zeros((batch, self._n_categories), dtype='float32') - id_ = self._discrete_column_matrix_st[condition_info['discrete_column_id']] + id_ = self._discrete_column_cond_st[condition_info['discrete_column_id']] id_ += condition_info['value_id'] vec[:, id_] = 1 return vec diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py index 0de69232..22a13327 100644 --- a/ctgan/synthesizers/ctgan.py +++ b/ctgan/synthesizers/ctgan.py @@ -440,6 +440,8 @@ def sample(self, n, condition_column=None, condition_value=None): Returns: numpy.ndarray or pandas.DataFrame """ + self._generator.eval() + if condition_column is not None and condition_value is not None: condition_info = self._transformer.convert_column_name_value_to_id( condition_column, condition_value) From 3d492e8e041454229eb11493effa526c1822122c Mon Sep 17 00:00:00 2001 From: Andres Algaba Date: Wed, 20 Jul 2022 14:43:08 +0200 Subject: [PATCH 2/7] add torch.no_grad for performance --- ctgan/synthesizers/ctgan.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py index 22a13327..081906ec 100644 --- a/ctgan/synthesizers/ctgan.py +++ b/ctgan/synthesizers/ctgan.py @@ -469,8 +469,10 @@ def sample(self, n, condition_column=None, condition_value=None): c1 = torch.from_numpy(c1).to(self._device) fakez = torch.cat([fakez, c1], dim=1) - fake = self._generator(fakez) - fakeact = self._apply_activate(fake) + with torch.no_grad(): + fake = self._generator(fakez) + fakeact = self._apply_activate(fake) + data.append(fakeact.detach().cpu().numpy()) data = np.concatenate(data, axis=0) From 8ce3cf78ae5c6563b8026a45e3fa5280863fc72b Mon Sep 17 00:00:00 2001 From: Andres Algaba Date: Thu, 21 Jul 2022 16:45:38 +0200 Subject: [PATCH 3/7] proposed fix in #169 --- ctgan/data_sampler.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/ctgan/data_sampler.py b/ctgan/data_sampler.py index 8a0a956b..d9c9f4e5 100644 --- a/ctgan/data_sampler.py +++ b/ctgan/data_sampler.py @@ -59,11 +59,13 @@ def is_discrete_column(column_info): st = 0 current_id = 0 + discrete_st = 0 current_cond_st = 0 for column_info in output_info: if is_discrete_column(column_info): span_info = column_info[0] ed = st + span_info.dim + discrete_ed = discrete_st + span_info.dim category_freq = np.sum(data[:, st:ed], axis=0) if log_frequency: category_freq = np.log(category_freq + 1) @@ -71,9 +73,13 @@ def is_discrete_column(column_info): self._discrete_column_category_prob[current_id, :span_info.dim] = category_prob self._discrete_column_cond_st[current_id] = current_cond_st self._discrete_column_n_category[current_id] = span_info.dim + + self._discrete_column_matrix_st[current_id] = discrete_st + current_cond_st += span_info.dim current_id += 1 st = ed + discrete_st = discrete_ed else: st += sum([span_info.dim for span_info in column_info]) @@ -150,7 +156,7 @@ def dim_cond_vec(self): def generate_cond_from_condition_column_info(self, condition_info, batch): """Generate the condition vector.""" vec = np.zeros((batch, self._n_categories), dtype='float32') - id_ = self._discrete_column_cond_st[condition_info['discrete_column_id']] + id_ = self._discrete_column_matrix_st[condition_info['discrete_column_id']] id_ += condition_info['value_id'] vec[:, id_] = 1 return vec From b401998d8d1597b0e22a4c1e06907d91f30a2dd2 Mon Sep 17 00:00:00 2001 From: Andres Algaba Date: Mon, 25 Jul 2022 09:39:28 +0200 Subject: [PATCH 4/7] undo the proposed fix in #169 --- ctgan/data_sampler.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/ctgan/data_sampler.py b/ctgan/data_sampler.py index d9c9f4e5..8a0a956b 100644 --- a/ctgan/data_sampler.py +++ b/ctgan/data_sampler.py @@ -59,13 +59,11 @@ def is_discrete_column(column_info): st = 0 current_id = 0 - discrete_st = 0 current_cond_st = 0 for column_info in output_info: if is_discrete_column(column_info): span_info = column_info[0] ed = st + span_info.dim - discrete_ed = discrete_st + span_info.dim category_freq = np.sum(data[:, st:ed], axis=0) if log_frequency: category_freq = np.log(category_freq + 1) @@ -73,13 +71,9 @@ def is_discrete_column(column_info): self._discrete_column_category_prob[current_id, :span_info.dim] = category_prob self._discrete_column_cond_st[current_id] = current_cond_st self._discrete_column_n_category[current_id] = span_info.dim - - self._discrete_column_matrix_st[current_id] = discrete_st - current_cond_st += span_info.dim current_id += 1 st = ed - discrete_st = discrete_ed else: st += sum([span_info.dim for span_info in column_info]) @@ -156,7 +150,7 @@ def dim_cond_vec(self): def generate_cond_from_condition_column_info(self, condition_info, batch): """Generate the condition vector.""" vec = np.zeros((batch, self._n_categories), dtype='float32') - id_ = self._discrete_column_matrix_st[condition_info['discrete_column_id']] + id_ = self._discrete_column_cond_st[condition_info['discrete_column_id']] id_ += condition_info['value_id'] vec[:, id_] = 1 return vec From 0bc340938f12df754520a71858dbd1a3ff4aedd4 Mon Sep 17 00:00:00 2001 From: Andres Algaba Date: Mon, 25 Jul 2022 09:46:34 +0200 Subject: [PATCH 5/7] new fix based on #169 --- ctgan/data_sampler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ctgan/data_sampler.py b/ctgan/data_sampler.py index 8a0a956b..7e7cd440 100644 --- a/ctgan/data_sampler.py +++ b/ctgan/data_sampler.py @@ -71,6 +71,8 @@ def is_discrete_column(column_info): self._discrete_column_category_prob[current_id, :span_info.dim] = category_prob self._discrete_column_cond_st[current_id] = current_cond_st self._discrete_column_n_category[current_id] = span_info.dim + self._discrete_column_matrix_st[current_id] = st + current_cond_st += span_info.dim current_id += 1 st = ed From af46041354407b23c1879b2a3f94cc49903bc165 Mon Sep 17 00:00:00 2001 From: Andres Algaba Date: Mon, 25 Jul 2022 09:59:37 +0200 Subject: [PATCH 6/7] Change (outdated) test --- tests/integration/synthesizer/test_ctgan.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/integration/synthesizer/test_ctgan.py b/tests/integration/synthesizer/test_ctgan.py index a750d3da..4846812e 100644 --- a/tests/integration/synthesizer/test_ctgan.py +++ b/tests/integration/synthesizer/test_ctgan.py @@ -80,19 +80,19 @@ def test_log_frequency(): discrete_columns = ['discrete'] - ctgan = CTGANSynthesizer(epochs=100) + ctgan = CTGANSynthesizer(epochs=1) ctgan.fit(data, discrete_columns) - sampled = ctgan.sample(10000) - counts = sampled['discrete'].value_counts() - assert counts['a'] < 6500 + assert ctgan._data_sampler._discrete_column_category_prob[0][0] < 0.95 + assert ctgan._data_sampler._discrete_column_category_prob[0][1] > 0.025 + assert ctgan._data_sampler._discrete_column_category_prob[0][2] > 0.025 - ctgan = CTGANSynthesizer(log_frequency=False, epochs=100) + ctgan = CTGANSynthesizer(log_frequency=False, epochs=1) ctgan.fit(data, discrete_columns) - sampled = ctgan.sample(10000) - counts = sampled['discrete'].value_counts() - assert counts['a'] > 9000 + assert ctgan._data_sampler._discrete_column_category_prob[0][0] == 0.95 + assert ctgan._data_sampler._discrete_column_category_prob[0][1] == 0.025 + assert ctgan._data_sampler._discrete_column_category_prob[0][2] == 0.025 def test_categorical_nan(): From abfb7fcf56bd33cecced51a2f837f12ad6c53673 Mon Sep 17 00:00:00 2001 From: Andres Algaba Date: Mon, 25 Jul 2022 10:12:37 +0200 Subject: [PATCH 7/7] add sampling test --- tests/integration/synthesizer/test_ctgan.py | 27 +++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/integration/synthesizer/test_ctgan.py b/tests/integration/synthesizer/test_ctgan.py index 4846812e..5c71a30b 100644 --- a/tests/integration/synthesizer/test_ctgan.py +++ b/tests/integration/synthesizer/test_ctgan.py @@ -134,6 +134,33 @@ def test_synthesizer_sample(): assert isinstance(samples, pd.DataFrame) +def test_synthesizer_sampling(): + """Test the CTGANSynthesizer sampling.""" + data = pd.DataFrame({ + 'continuous': np.random.random(1000), + 'discrete': np.repeat(['a', 'b', 'c'], [950, 25, 25]) + }) + + discrete_columns = ['discrete'] + + ctgan = CTGANSynthesizer(epochs=100) + ctgan.fit(data, discrete_columns) + + samples = ctgan.sample(1000) + assert samples['discrete'].value_counts()['a'] > 800 + assert samples['discrete'].value_counts()['b'] < 100 + assert samples['discrete'].value_counts()['c'] < 100 + + samples = ctgan.sample(1000, condition_column='discrete', condition_value='a') + assert samples['discrete'].value_counts()['a'] > 750 + + samples = ctgan.sample(1000, condition_column='discrete', condition_value='b') + assert samples['discrete'].value_counts()['b'] > 750 + + samples = ctgan.sample(1000, condition_column='discrete', condition_value='c') + assert samples['discrete'].value_counts()['c'] > 750 + + def test_save_load(): """Test the CTGANSynthesizer load/save methods.""" data = pd.DataFrame({