Patch summary (text before the first "diff --git" header is skipped by
`git apply` and `patch`):

- ctgan/data_sampler.py: store the running start index `st` of each discrete
  column in `_discrete_column_matrix_st`, and build the condition vector in
  `generate_cond_from_condition_column_info` from `_discrete_column_cond_st`.
- ctgan/synthesizers/ctgan.py: `sample()` switches the generator to eval mode
  and runs the forward pass and activation under `torch.no_grad()`.
- tests/integration/synthesizer/test_ctgan.py: `test_log_frequency` now trains
  for one epoch and asserts on the stored category probabilities;
  `test_synthesizer_sampling` is added for plain and conditional sampling.
  The "< 100" assertions use `value_counts().get(<cat>, 0)` because
  `value_counts()` omits categories that never occur, so indexing with
  `['b']` / `['c']` would raise KeyError exactly when the check should pass.

NOTE(review): line breaks, blank context lines and indentation were
reconstructed from a whitespace-collapsed copy of this patch so that every
hunk matches its @@ line counts -- confirm against the original commit. The
blob ids on the `index` lines are the original ones and do not reflect the
`.get(...)` change in the test file.

diff --git a/ctgan/data_sampler.py b/ctgan/data_sampler.py
index 5cbf339d..7e7cd440 100644
--- a/ctgan/data_sampler.py
+++ b/ctgan/data_sampler.py
@@ -71,6 +71,8 @@ def is_discrete_column(column_info):
                 self._discrete_column_category_prob[current_id, :span_info.dim] = category_prob
                 self._discrete_column_cond_st[current_id] = current_cond_st
                 self._discrete_column_n_category[current_id] = span_info.dim
+                self._discrete_column_matrix_st[current_id] = st
+
                 current_cond_st += span_info.dim
                 current_id += 1
                 st = ed
@@ -150,7 +152,7 @@ def dim_cond_vec(self):
     def generate_cond_from_condition_column_info(self, condition_info, batch):
         """Generate the condition vector."""
         vec = np.zeros((batch, self._n_categories), dtype='float32')
-        id_ = self._discrete_column_matrix_st[condition_info['discrete_column_id']]
+        id_ = self._discrete_column_cond_st[condition_info['discrete_column_id']]
         id_ += condition_info['value_id']
         vec[:, id_] = 1
         return vec
diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py
index 0de69232..081906ec 100644
--- a/ctgan/synthesizers/ctgan.py
+++ b/ctgan/synthesizers/ctgan.py
@@ -440,6 +440,8 @@ def sample(self, n, condition_column=None, condition_value=None):
         Returns:
             numpy.ndarray or pandas.DataFrame
         """
+        self._generator.eval()
+
         if condition_column is not None and condition_value is not None:
             condition_info = self._transformer.convert_column_name_value_to_id(
                 condition_column, condition_value)
@@ -467,8 +469,10 @@ def sample(self, n, condition_column=None, condition_value=None):
                 c1 = torch.from_numpy(c1).to(self._device)
                 fakez = torch.cat([fakez, c1], dim=1)
 
-            fake = self._generator(fakez)
-            fakeact = self._apply_activate(fake)
+            with torch.no_grad():
+                fake = self._generator(fakez)
+                fakeact = self._apply_activate(fake)
+
             data.append(fakeact.detach().cpu().numpy())
 
         data = np.concatenate(data, axis=0)
diff --git a/tests/integration/synthesizer/test_ctgan.py b/tests/integration/synthesizer/test_ctgan.py
index a750d3da..5c71a30b 100644
--- a/tests/integration/synthesizer/test_ctgan.py
+++ b/tests/integration/synthesizer/test_ctgan.py
@@ -80,19 +80,19 @@ def test_log_frequency():
 
     discrete_columns = ['discrete']
 
-    ctgan = CTGANSynthesizer(epochs=100)
+    ctgan = CTGANSynthesizer(epochs=1)
     ctgan.fit(data, discrete_columns)
 
-    sampled = ctgan.sample(10000)
-    counts = sampled['discrete'].value_counts()
-    assert counts['a'] < 6500
+    assert ctgan._data_sampler._discrete_column_category_prob[0][0] < 0.95
+    assert ctgan._data_sampler._discrete_column_category_prob[0][1] > 0.025
+    assert ctgan._data_sampler._discrete_column_category_prob[0][2] > 0.025
 
-    ctgan = CTGANSynthesizer(log_frequency=False, epochs=100)
+    ctgan = CTGANSynthesizer(log_frequency=False, epochs=1)
     ctgan.fit(data, discrete_columns)
 
-    sampled = ctgan.sample(10000)
-    counts = sampled['discrete'].value_counts()
-    assert counts['a'] > 9000
+    assert ctgan._data_sampler._discrete_column_category_prob[0][0] == 0.95
+    assert ctgan._data_sampler._discrete_column_category_prob[0][1] == 0.025
+    assert ctgan._data_sampler._discrete_column_category_prob[0][2] == 0.025
 
 
 def test_categorical_nan():
@@ -134,6 +134,33 @@ def test_synthesizer_sample():
     assert isinstance(samples, pd.DataFrame)
 
 
+def test_synthesizer_sampling():
+    """Test the CTGANSynthesizer sampling."""
+    data = pd.DataFrame({
+        'continuous': np.random.random(1000),
+        'discrete': np.repeat(['a', 'b', 'c'], [950, 25, 25])
+    })
+
+    discrete_columns = ['discrete']
+
+    ctgan = CTGANSynthesizer(epochs=100)
+    ctgan.fit(data, discrete_columns)
+
+    samples = ctgan.sample(1000)
+    assert samples['discrete'].value_counts()['a'] > 800
+    assert samples['discrete'].value_counts().get('b', 0) < 100
+    assert samples['discrete'].value_counts().get('c', 0) < 100
+
+    samples = ctgan.sample(1000, condition_column='discrete', condition_value='a')
+    assert samples['discrete'].value_counts()['a'] > 750
+
+    samples = ctgan.sample(1000, condition_column='discrete', condition_value='b')
+    assert samples['discrete'].value_counts()['b'] > 750
+
+    samples = ctgan.sample(1000, condition_column='discrete', condition_value='c')
+    assert samples['discrete'].value_counts()['c'] > 750
+
+
 def test_save_load():
     """Test the CTGANSynthesizer load/save methods."""
     data = pd.DataFrame({