1
1
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
2
2
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
3
3
4
- from typing import Optional
4
+ from typing import Callable , Optional
5
5
6
6
import torch
7
7
from nflows .nn .nde .made import MADE
15
15
16
16
17
17
class CategoricalMADE (MADE ):
18
+ """Conditional density (mass) estimation for a n-dim categorical random variable.
19
+
20
+ Takes as input parameters theta and learns the parameters p of a Categorical.
21
+
22
+ Defines log prob and sample functions.
23
+ """
24
+
18
25
def __init__ (
19
26
self ,
20
- categories , # Tensor[int]
21
- hidden_features ,
22
- context_features = None ,
23
- num_blocks = 2 ,
24
- use_residual_blocks = True ,
25
- random_mask = False ,
26
- activation = F .relu ,
27
- dropout_probability = 0.0 ,
28
- use_batch_norm = False ,
29
- epsilon = 1e-2 ,
30
- custom_initialization = True ,
27
+ num_categories : Tensor , # Tensor[int]
28
+ hidden_features : int ,
29
+ context_features : Optional [ int ] = None ,
30
+ num_blocks : int = 2 ,
31
+ use_residual_blocks : bool = True ,
32
+ random_mask : bool = False ,
33
+ activation : Callable = F .relu ,
34
+ dropout_probability : float = 0.0 ,
35
+ use_batch_norm : bool = False ,
36
+ epsilon : float = 1e-2 ,
37
+ custom_initialization : bool = True ,
31
38
embedding_net : Optional [nn .Module ] = nn .Identity (),
32
39
):
40
+ """Initialize the neural net.
41
+
42
+ Args:
43
+ num_categories: number of categories for each variable. len(num_categories)
44
+ defines the number of input units, i.e., dimensionality of the features.
45
+ max(num_categories) defines the number of output units, i.e., the largest
46
+ number of categories.
47
+ hidden_features: number of hidden units per layer.
48
+ num_blocks: number of hidden blocks.
49
+ embedding_net: embedding net for input.
50
+ """
33
51
if use_residual_blocks and random_mask :
34
52
raise ValueError ("Residual blocks can't be used with random masks." )
35
53
36
- self .num_variables = len (categories )
37
- self .num_categories = int (max (categories ))
38
- self .categories = categories
54
+ self .num_variables = len (num_categories )
55
+ self .num_categories = int (torch .max (num_categories ))
39
56
self .mask = torch .zeros (self .num_variables , self .num_categories )
40
- for i , c in enumerate (categories ):
57
+ for i , c in enumerate (num_categories ):
41
58
self .mask [i , :c ] = 1
42
59
43
60
super ().__init__ (
@@ -60,7 +77,18 @@ def __init__(
60
77
if custom_initialization :
61
78
self ._initialize ()
62
79
63
- def forward (self , inputs , context = None ):
80
+ def forward (self , inputs : Tensor , context : Optional [Tensor ] = None ) -> Tensor :
81
+ r"""Forward pass of the categorical density estimator network to compute the
82
+ conditional density.
83
+
84
+ Args:
85
+ inputs: Original data. (batch_size, *input_shape)
86
+ context: Conditioning variable. (batch_size, *condition_shape)
87
+
88
+ Returns:
89
+ Predicted categorical probabilities. (batch_size, *input_shape,
90
+ num_categories)
91
+ """
64
92
embedded_context = self .embedding_net .forward (context )
65
93
return super ().forward (inputs , context = embedded_context )
66
94
@@ -69,8 +97,16 @@ def compute_probs(self, outputs):
69
97
ps = ps / ps .sum (dim = - 1 , keepdim = True )
70
98
return ps
71
99
72
- # outputs (batch_size, num_variables, num_categories)
73
- def log_prob (self , inputs , context = None ):
100
+ def log_prob (self , inputs : Tensor , context : Optional [Tensor ] = None ) -> Tensor :
101
+ r"""Return log-probability of samples.
102
+
103
+ Args:
104
+ inputs: Input datapoints of shape `(batch_size, *input_shape)`.
105
+ context: Context of shape `(batch_size, *condition_shape)`.
106
+
107
+ Returns:
108
+ Log-probabilities of shape `(batch_size, num_variables, num_categories)`.
109
+ """
74
110
outputs = self .forward (inputs , context = context )
75
111
outputs = outputs .reshape (* inputs .shape , self .num_categories )
76
112
ps = self .compute_probs (outputs )
0 commit comments