@@ -29,15 +29,16 @@ def __init__(
29
29
use_batch_norm = False ,
30
30
epsilon = 1e-2 ,
31
31
custom_initialization = True ,
32
+ #TODO: embedding_net: Optional[nn.Module] = None,
32
33
):
33
34
34
35
if use_residual_blocks and random_mask :
35
36
raise ValueError ("Residual blocks can't be used with random masks." )
36
37
37
38
self .num_variables = len (categories )
38
- self .max_categories = max (categories )
39
+ self .num_categories = int ( max (categories ) )
39
40
self .categories = categories
40
- self .mask = torch .zeros (self .num_variables , self .max_categories )
41
+ self .mask = torch .zeros (self .num_variables , self .num_categories )
41
42
for i , c in enumerate (categories ):
42
43
self .mask [i , :c ] = 1
43
44
@@ -46,7 +47,7 @@ def __init__(
46
47
hidden_features ,
47
48
context_features = context_features ,
48
49
num_blocks = num_blocks ,
49
- output_multiplier = self .max_categories ,
50
+ output_multiplier = self .num_categories ,
50
51
use_residual_blocks = use_residual_blocks ,
51
52
random_mask = random_mask ,
52
53
activation = activation ,
@@ -68,10 +69,10 @@ def compute_probs(self, outputs):
68
69
ps = ps / ps .sum (dim = - 1 , keepdim = True )
69
70
return ps
70
71
71
- # outputs (batch_size, num_variables, max_categories )
72
+ # outputs (batch_size, num_variables, num_categories )
72
73
def log_prob (self , inputs , context = None ):
73
74
outputs = self .forward (inputs , context = context )
74
- outputs = outputs .reshape (* inputs .shape , self .max_categories )
75
+ outputs = outputs .reshape (* inputs .shape , self .num_categories )
75
76
ps = self .compute_probs (outputs )
76
77
77
78
# categorical log prob
@@ -91,7 +92,7 @@ def sample(self, num_samples, context=None):
91
92
92
93
for variable in range (self .num_variables ):
93
94
outputs = self .forward (samples , context )
94
- outputs = outputs .reshape (* samples .shape , self .max_categories )
95
+ outputs = outputs .reshape (* samples .shape , self .num_categories )
95
96
ps = self .compute_probs (outputs )
96
97
samples [:, variable ] = Categorical (probs = ps [:,variable ]).sample ()
97
98
0 commit comments