
Commit 2c9dcc1

aobo-y authored and facebook-github-bot committed
improve documentation of STG (#1100)
Summary:
- correct arg type `optional` based on our convention
- code block highlight
- add example

Pull Request resolved: #1100
Reviewed By: cyrjano
Differential Revision: D42493054
Pulled By: aobo-y
fbshipit-source-id: 9491d0202a9bcd73ace93482500ffc7ca902c819
1 parent cc5f468 commit 2c9dcc1

3 files changed, +88 −57 lines changed


captum/module/binary_concrete_stochastic_gates.py

Lines changed: 45 additions & 29 deletions
@@ -47,8 +47,16 @@ class BinaryConcreteStochasticGates(StochasticGatesBase):
     Then use hard-sigmoid rectification to "fold" the parts smaller than 0 or larger
     than 1 back to 0 and 1.
 
-    More details can be found in the
-    `original paper <https://arxiv.org/abs/1712.01312>`.
+    More details can be found in the original paper:
+    https://arxiv.org/abs/1712.01312
+
+    Examples::
+
+        >>> n_params = 5  # number of parameters
+        >>> stg = BinaryConcreteStochasticGates(n_params, reg_weight=0.01)
+        >>> inputs = torch.randn(3, n_params)  # mock inputs with batch size of 3
+        >>> gated_inputs, reg = stg(inputs)  # gate the inputs
     """
 
     def __init__(
@@ -66,42 +74,42 @@ def __init__(
         Args:
             n_gates (int): number of gates.
 
-            mask (Optional[Tensor]): If provided, this allows grouping multiple
+            mask (Tensor, optional): If provided, this allows grouping multiple
                 input tensor elements to share the same stochastic gate.
                 This tensor should be broadcastable to match the input shape
                 and contain integers in the range 0 to n_gates - 1.
                 Indices grouped to the same stochastic gate should have the same value.
                 If not provided, each element in the input tensor
-                (on dimensions other than dim 0 - batch dim) is gated separately.
+                (on dimensions other than dim 0, i.e., batch dim) is gated separately.
                 Default: None
 
-            reg_weight (Optional[float]): rescaling weight for L0 regularization term.
+            reg_weight (float, optional): rescaling weight for L0 regularization term.
                 Default: 1.0
 
-            temperature (float): temperature of the concrete distribution, controls
-                the degree of approximation, as 0 means the original Bernoulli
+            temperature (float, optional): temperature of the concrete distribution,
+                controls the degree of approximation, as 0 means the original Bernoulli
                 without relaxation. The value should be between 0 and 1.
                 Default: 2/3
 
-            lower_bound (float): the lower bound to "stretch" the binary concrete
-                distribution
+            lower_bound (float, optional): the lower bound to "stretch" the binary
+                concrete distribution
                 Default: -0.1
 
-            upper_bound (float): the upper bound to "stretch" the binary concrete
-                distribution
+            upper_bound (float, optional): the upper bound to "stretch" the binary
+                concrete distribution
                 Default: 1.1
 
-            eps (float): term to improve numerical stability in binary concerete
-                sampling
+            eps (float, optional): term to improve numerical stability in binary
+                concrete sampling
                 Default: 1e-8
 
-            reg_reduction (str, optional): the reduction to apply to
-                the regularization: 'none'|'mean'|'sum'. 'none': no reduction will be
-                applied and it will be the same as the return of get_active_probs,
-                'mean': the sum of the gates non-zero probabilities will be divided by
-                the number of gates, 'sum': the gates non-zero probabilities will
+            reg_reduction (str, optional): the reduction to apply to the regularization:
+                ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be
+                applied and it will be the same as the return of ``get_active_probs``,
+                ``'mean'``: the sum of the gates non-zero probabilities will be divided
+                by the number of gates, ``'sum'``: the gates non-zero probabilities will
                 be summed.
-                Default: 'sum'
+                Default: ``'sum'``
         """
         super().__init__(
             n_gates, mask=mask, reg_weight=reg_weight, reg_reduction=reg_reduction
@@ -193,7 +201,7 @@ def _from_pretrained(cls, log_alpha_param: Tensor, *args, **kwargs):
             log_alpha_param (Tensor): FloatTensor containing weights for
                 the pretrained log_alpha
 
-            mask (Optional[Tensor]): If provided, this allows grouping multiple
+            mask (Tensor, optional): If provided, this allows grouping multiple
                 input tensor elements to share the same stochastic gate.
                 This tensor should be broadcastable to match the input shape
                 and contain integers in the range 0 to n_gates - 1.
@@ -202,26 +210,34 @@ def _from_pretrained(cls, log_alpha_param: Tensor, *args, **kwargs):
                 (on dimensions other than dim 0 - batch dim) is gated separately.
                 Default: None
 
-            reg_weight (Optional[float]): rescaling weight for L0 regularization term.
+            reg_weight (float, optional): rescaling weight for L0 regularization term.
                 Default: 1.0
 
-            temperature (float): temperature of the concrete distribution, controls
-                the degree of approximation, as 0 means the original Bernoulli
+            temperature (float, optional): temperature of the concrete distribution,
+                controls the degree of approximation, as 0 means the original Bernoulli
                 without relaxation. The value should be between 0 and 1.
                 Default: 2/3
 
-            lower_bound (float): the lower bound to "stretch" the binary concrete
-                distribution
+            lower_bound (float, optional): the lower bound to "stretch" the binary
+                concrete distribution
                 Default: -0.1
 
-            upper_bound (float): the upper bound to "stretch" the binary concrete
-                distribution
+            upper_bound (float, optional): the upper bound to "stretch" the binary
+                concrete distribution
                 Default: 1.1
 
-            eps (float): term to improve numerical stability in binary concerete
-                sampling
+            eps (float, optional): term to improve numerical stability in binary
+                concrete sampling
                 Default: 1e-8
 
+            reg_reduction (str, optional): the reduction to apply to the regularization:
+                ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be
+                applied and it will be the same as the return of ``get_active_probs``,
+                ``'mean'``: the sum of the gates non-zero probabilities will be divided
+                by the number of gates, ``'sum'``: the gates non-zero probabilities will
+                be summed.
+                Default: ``'sum'``
+
         Returns:
             stg (BinaryConcreteStochasticGates): StochasticGates instance
         """

captum/module/gaussian_stochastic_gates.py

Lines changed: 32 additions & 17 deletions
@@ -28,8 +28,15 @@ class GaussianStochasticGates(StochasticGatesBase):
     within 0 and 1, gaussian does not have boundaries. So hard-sigmoid rectification
     is used to "fold" the parts smaller than 0 or larger than 1 back to 0 and 1.
 
-    More details can be found in the
-    `original paper <https://arxiv.org/abs/1810.04247>`.
+    More details can be found in the original paper:
+    https://arxiv.org/abs/1810.04247
+
+    Examples::
+
+        >>> n_params = 5  # number of gates
+        >>> stg = GaussianStochasticGates(n_params, reg_weight=0.01)
+        >>> inputs = torch.randn(3, n_params)  # mock inputs with batch size of 3
+        >>> gated_inputs, reg = stg(inputs)  # gate the inputs
     """
 
     def __init__(
@@ -44,28 +51,28 @@ def __init__(
         Args:
             n_gates (int): number of gates.
 
-            mask (Optional[Tensor]): If provided, this allows grouping multiple
+            mask (Tensor, optional): If provided, this allows grouping multiple
                 input tensor elements to share the same stochastic gate.
                 This tensor should be broadcastable to match the input shape
                 and contain integers in the range 0 to n_gates - 1.
                 Indices grouped to the same stochastic gate should have the same value.
                 If not provided, each element in the input tensor
-                (on dimensions other than dim 0 - batch dim) is gated separately.
+                (on dimensions other than dim 0, i.e., batch dim) is gated separately.
                 Default: None
 
-            reg_weight (Optional[float]): rescaling weight for L0 regularization term.
+            reg_weight (float, optional): rescaling weight for L0 regularization term.
                 Default: 1.0
 
-            std (Optional[float]): standard deviation that will be fixed throughout.
-                Default: 0.5 (by paper reference)
+            std (float, optional): standard deviation that will be fixed throughout.
+                Default: 0.5
 
-            reg_reduction (str, optional): the reduction to apply to
-                the regularization: 'none'|'mean'|'sum'. 'none': no reduction will be
-                applied and it will be the same as the return of get_active_probs,
-                'mean': the sum of the gates non-zero probabilities will be divided by
-                the number of gates, 'sum': the gates non-zero probabilities will
+            reg_reduction (str, optional): the reduction to apply to the regularization:
+                ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be
+                applied and it will be the same as the return of ``get_active_probs``,
+                ``'mean'``: the sum of the gates non-zero probabilities will be divided
+                by the number of gates, ``'sum'``: the gates non-zero probabilities will
                 be summed.
-                Default: 'sum'
+                Default: ``'sum'``
         """
         super().__init__(
             n_gates, mask=mask, reg_weight=reg_weight, reg_reduction=reg_reduction
@@ -126,7 +133,7 @@ def _from_pretrained(cls, mu: Tensor, *args, **kwargs):
         Args:
             mu (Tensor): FloatTensor containing weights for the pretrained mu
 
-            mask (Optional[Tensor]): If provided, this allows grouping multiple
+            mask (Tensor, optional): If provided, this allows grouping multiple
                 input tensor elements to share the same stochastic gate.
                 This tensor should be broadcastable to match the input shape
                 and contain integers in the range 0 to n_gates - 1.
@@ -135,11 +142,19 @@ def _from_pretrained(cls, mu: Tensor, *args, **kwargs):
                 (on dimensions other than dim 0 - batch dim) is gated separately.
                 Default: None
 
-            reg_weight (Optional[float]): rescaling weight for L0 regularization term.
+            reg_weight (float, optional): rescaling weight for L0 regularization term.
                 Default: 1.0
 
-            std (Optional[float]): standard deviation that will be fixed throughout.
-                Default: 0.5 (by paper reference)
+            std (float, optional): standard deviation that will be fixed throughout.
+                Default: 0.5
+
+            reg_reduction (str, optional): the reduction to apply to the regularization:
+                ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be
+                applied and it will be the same as the return of ``get_active_probs``,
+                ``'mean'``: the sum of the gates non-zero probabilities will be divided
+                by the number of gates, ``'sum'``: the gates non-zero probabilities will
+                be summed.
+                Default: ``'sum'``
 
         Returns:
             stg (GaussianStochasticGates): StochasticGates instance
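
Both gate variants lean on the hard-sigmoid rectification the docstrings describe: samples that land below 0 or above 1 are "folded" back onto the boundaries. As a generic sketch of that operation (not Captum's internal code):

    import torch

    def hard_sigmoid(x: torch.Tensor) -> torch.Tensor:
        # Fold values below 0 up to 0 and values above 1 down to 1;
        # values already inside [0, 1] pass through unchanged.
        return torch.clamp(x, min=0.0, max=1.0)

    samples = torch.tensor([-0.3, 0.4, 1.2])  # e.g., stretched or noisy gate samples
    print(hard_sigmoid(samples))              # tensor([0.0000, 0.4000, 1.0000])

Per the docstrings, the gaussian gates rectify samples from a fixed-std gaussian around the learned means, while the binary concrete gates rectify samples stretched into (lower_bound, upper_bound), which is what makes exact 0s and 1s reachable.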

captum/module/stochastic_gates_base.py

Lines changed: 11 additions & 11 deletions
@@ -39,7 +39,7 @@ def __init__(
         Args:
             n_gates (int): number of gates.
 
-            mask (Optional[Tensor]): If provided, this allows grouping multiple
+            mask (Tensor, optional): If provided, this allows grouping multiple
                 input tensor elements to share the same stochastic gate.
                 This tensor should be broadcastable to match the input shape
                 and contain integers in the range 0 to n_gates - 1.
@@ -48,16 +48,16 @@ def __init__(
                 (on dimensions other than dim 0 - batch dim) is gated separately.
                 Default: None
 
-            reg_weight (Optional[float]): rescaling weight for L0 regularization term.
+            reg_weight (float, optional): rescaling weight for L0 regularization term.
                 Default: 1.0
 
-            reg_reduction (str, optional): the reduction to apply to
-                the regularization: 'none'|'mean'|'sum'. 'none': no reduction will be
-                applied and it will be the same as the return of get_active_probs,
-                'mean': the sum of the gates non-zero probabilities will be divided by
-                the number of gates, 'sum': the gates non-zero probabilities will
+            reg_reduction (str, optional): the reduction to apply to the regularization:
+                ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be
+                applied and it will be the same as the return of ``get_active_probs``,
+                ``'mean'``: the sum of the gates non-zero probabilities will be divided
+                by the number of gates, ``'sum'``: the gates non-zero probabilities will
                 be summed.
-                Default: 'sum'
+                Default: ``'sum'``
         """
         super().__init__()
@@ -143,13 +143,13 @@ def get_gate_values(self, clamp: bool = True) -> Tensor:
         optionally clamped within 0 and 1.
 
         Args:
-            clamp (bool): whether to clamp the gate values or not. As smoothed Bernoulli
-                variables, gate values are clamped within 0 and 1 by default.
+            clamp (bool, optional): whether to clamp the gate values or not. As smoothed
+                Bernoulli variables, gate values are clamped within 0 and 1 by default.
                 Turn this off to get the raw means of the underneath
                 distribution (e.g., concrete, gaussian), which can be useful to
                 differentiate the gates' importance when multiple gate
                 values are beyond 0 or 1.
-                Default: True
+                Default: ``True``
 
         Returns:
             Tensor:
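
One more point the reworded reg_reduction docstring makes checkable: the three reductions are plain tensor reductions over the gates' non-zero probabilities. Below is a hypothetical helper mirroring that documented behavior, not the library's internal implementation; prob_nonzero stands in for the documented return of get_active_probs.

    import torch

    def reduce_reg(prob_nonzero: torch.Tensor, reg_reduction: str = "sum") -> torch.Tensor:
        # prob_nonzero: per-gate probability of each gate being non-zero.
        if reg_reduction == "sum":
            return prob_nonzero.sum()
        if reg_reduction == "mean":
            # "the sum of the gates non-zero probabilities ... divided by the number of gates"
            return prob_nonzero.sum() / prob_nonzero.numel()
        return prob_nonzero  # 'none': same as the return of get_active_probs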