@@ -3864,7 +3864,7 @@ def kl_divergence(self, z, k_ohe=None):
38643864 q_zx = self .encoder (z )
38653865
38663866 # 2. Get the prior distribution
3867- if isinstance (self .prior , MoCPPrior ):
3867+ if isinstance (self .prior , ( MoCPPrior , VMMPrior ) ):
38683868 p_z = self .prior (k_ohe ) # select Gaussian component of the prior
38693869 else :
38703870 p_z = self .prior () # sample from standard Gaussian prior
@@ -4016,7 +4016,7 @@ def __init__(
40164016 self .prior = MoCPPrior (d_z , n_bios )
40174017
40184018 elif prior == "VMM" :
4019- self .encoder = MoGEncoder (nn .Sequential (* modules ), n_bios )
4019+ self .encoder = MoCPEncoder (nn .Sequential (* modules ), n_bios )
40204020 self .prior = VMMPrior (
40214021 d_z ,
40224022 n_features ,
@@ -4161,7 +4161,7 @@ def train_vae(
41614161 bio_penalty = 0.0 # ensures points from the same biological group to be mapped on the same cluster
41624162 cluster_penalty = 0.0 # ensures gaussian components to not overlap
41634163
4164- if isinstance (self .vae .prior , MoCPPrior ):
4164+ if isinstance (self .vae .prior , ( MoCPPrior , VMMPrior ) ):
41654165 # Compute penalty for biological mapping
41664166 pred_bio , _ , _ = self .vae .encoder .encode (
41674167 torch .cat ([x , ohe_batch ], dim = 1 )
@@ -4280,7 +4280,7 @@ def batch_correct(
42804280 bio_penalty = 0.0 # ensures points from the same biological group to be mapped on the same cluster
42814281 cluster_penalty = 0.0 # ensures gaussian components to not overlap
42824282
4283- if isinstance (self .vae .prior , MoCPPrior ):
4283+ if isinstance (self .vae .prior , ( MoCPPrior , VMMPrior ) ):
42844284 # Compute penalty for biological mapping
42854285 pred_bio , _ , _ = self .vae .encoder .encode (
42864286 torch .cat ([x , ohe_batch ], dim = 1 )
@@ -4290,11 +4290,6 @@ def batch_correct(
42904290 pred_bio , ohe_bio .argmax (dim = 1 )
42914291 )
42924292
4293- # Compute penalty for group clusters
4294- cluster_penalty += (
4295- w_cluster_penalty
4296- ) * self .vae .prior .cluster_loss ()
4297-
42984293 # Total loss is reconstruction loss + KL divergence loss
42994294 vae_loss = recon_loss + kl_loss + bio_penalty + cluster_penalty
43004295
0 commit comments