From 4d2a2347a0d718a648ec72eab49daca9e31f6834 Mon Sep 17 00:00:00 2001 From: Mazen Ali Date: Thu, 21 Nov 2024 11:10:22 +0100 Subject: [PATCH] fix (ctgan): add discriminator to model attributes --- ctgan/synthesizers/ctgan.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py index 5fdbc269..4357a111 100644 --- a/ctgan/synthesizers/ctgan.py +++ b/ctgan/synthesizers/ctgan.py @@ -190,6 +190,7 @@ def __init__( self._transformer = None self._data_sampler = None self._generator = None + self._discriminator = None self.loss_values = None @staticmethod @@ -330,7 +331,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None): self._embedding_dim + self._data_sampler.dim_cond_vec(), self._generator_dim, data_dim ).to(self._device) - discriminator = Discriminator( + self._discriminator = Discriminator( data_dim + self._data_sampler.dim_cond_vec(), self._discriminator_dim, pac=self.pac ).to(self._device) @@ -342,7 +343,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None): ) optimizerD = optim.Adam( - discriminator.parameters(), + self._discriminator.parameters(), lr=self._discriminator_lr, betas=(0.5, 0.9), weight_decay=self._discriminator_decay, @@ -395,10 +396,10 @@ def fit(self, train_data, discrete_columns=(), epochs=None): real_cat = real fake_cat = fakeact - y_fake = discriminator(fake_cat) - y_real = discriminator(real_cat) + y_fake = self._discriminator(fake_cat) + y_real = self._discriminator(real_cat) - pen = discriminator.calc_gradient_penalty( + pen = self._discriminator.calc_gradient_penalty( real_cat, fake_cat, self._device, self.pac ) loss_d = -(torch.mean(y_real) - torch.mean(y_fake)) @@ -423,9 +424,9 @@ def fit(self, train_data, discrete_columns=(), epochs=None): fakeact = self._apply_activate(fake) if c1 is not None: - y_fake = discriminator(torch.cat([fakeact, c1], dim=1)) + y_fake = self._discriminator(torch.cat([fakeact, c1], dim=1)) else: - y_fake = discriminator(fakeact) + y_fake = self._discriminator(fakeact) if condvec is None: cross_entropy = 0 @@ -520,3 +521,5 @@ def set_device(self, device): self._device = device if self._generator is not None: self._generator.to(self._device) + if self._discriminator is not None: + self._discriminator.to(self._device)