Skip to content

Commit 040b5e5

Browse files
authored
Merge pull request #28 from Multiomics-Analytics-Group/27-ecoli-vmm-prior-not-working
Fixed all bugs with VMMPrior
2 parents 3c63675 + 3b3f54a commit 040b5e5

File tree

1 file changed

+4
-9
lines changed

1 file changed

+4
-9
lines changed

src/abaco/ABaCo.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3864,7 +3864,7 @@ def kl_divergence(self, z, k_ohe=None):
38643864
q_zx = self.encoder(z)
38653865

38663866
# 2. Get the prior distribution
3867-
if isinstance(self.prior, MoCPPrior):
3867+
if isinstance(self.prior, (MoCPPrior, VMMPrior)):
38683868
p_z = self.prior(k_ohe) # select Gaussian component of the prior
38693869
else:
38703870
p_z = self.prior() # sample from standard Gaussian prior
@@ -4016,7 +4016,7 @@ def __init__(
40164016
self.prior = MoCPPrior(d_z, n_bios)
40174017

40184018
elif prior == "VMM":
4019-
self.encoder = MoGEncoder(nn.Sequential(*modules), n_bios)
4019+
self.encoder = MoCPEncoder(nn.Sequential(*modules), n_bios)
40204020
self.prior = VMMPrior(
40214021
d_z,
40224022
n_features,
@@ -4161,7 +4161,7 @@ def train_vae(
41614161
bio_penalty = 0.0 # ensures points from the same biological group to be mapped on the same cluster
41624162
cluster_penalty = 0.0 # ensures gaussian components to not overlap
41634163

4164-
if isinstance(self.vae.prior, MoCPPrior):
4164+
if isinstance(self.vae.prior, (MoCPPrior, VMMPrior)):
41654165
# Compute penalty for biological mapping
41664166
pred_bio, _, _ = self.vae.encoder.encode(
41674167
torch.cat([x, ohe_batch], dim=1)
@@ -4280,7 +4280,7 @@ def batch_correct(
42804280
bio_penalty = 0.0 # ensures points from the same biological group to be mapped on the same cluster
42814281
cluster_penalty = 0.0 # ensures gaussian components to not overlap
42824282

4283-
if isinstance(self.vae.prior, MoCPPrior):
4283+
if isinstance(self.vae.prior, (MoCPPrior, VMMPrior)):
42844284
# Compute penalty for biological mapping
42854285
pred_bio, _, _ = self.vae.encoder.encode(
42864286
torch.cat([x, ohe_batch], dim=1)
@@ -4290,11 +4290,6 @@ def batch_correct(
42904290
pred_bio, ohe_bio.argmax(dim=1)
42914291
)
42924292

4293-
# Compute penalty for group clusters
4294-
cluster_penalty += (
4295-
w_cluster_penalty
4296-
) * self.vae.prior.cluster_loss()
4297-
42984293
# Total loss is reconstruction loss + KL divergence loss
42994294
vae_loss = recon_loss + kl_loss + bio_penalty + cluster_penalty
43004295

0 commit comments

Comments
 (0)