@@ -47,8 +47,16 @@ class BinaryConcreteStochasticGates(StochasticGatesBase):
47
47
Then use hard-sigmoid rectification to "fold" the parts smaller than 0 or larger
48
48
than 1 back to 0 and 1.
49
49
50
- More details can be found in the
51
- `original paper <https://arxiv.org/abs/1712.01312>`.
50
+ More details can be found in the original paper:
51
+ https://arxiv.org/abs/1712.01312
52
+
53
+ Examples::
54
+
55
+ >>> n_params = 5 # number of parameters
56
+ >>> stg = BinaryConcreteStochasticGates(n_params, reg_weight=0.01)
57
+ >>> inputs = torch.randn(3, n_params) # mock inputs with batch size of 3
58
+ >>> gated_inputs, reg = stg(mock_inputs) # gate the inputs
59
+
52
60
"""
53
61
54
62
def __init__ (
@@ -66,42 +74,42 @@ def __init__(
66
74
Args:
67
75
n_gates (int): number of gates.
68
76
69
- mask (Optional[ Tensor] ): If provided, this allows grouping multiple
77
+ mask (Tensor, optional ): If provided, this allows grouping multiple
70
78
input tensor elements to share the same stochastic gate.
71
79
This tensor should be broadcastable to match the input shape
72
80
and contain integers in the range 0 to n_gates - 1.
73
81
Indices grouped to the same stochastic gate should have the same value.
74
82
If not provided, each element in the input tensor
75
- (on dimensions other than dim 0 - batch dim) is gated separately.
83
+ (on dimensions other than dim 0, i.e., batch dim) is gated separately.
76
84
Default: None
77
85
78
- reg_weight (Optional[ float] ): rescaling weight for L0 regularization term.
86
+ reg_weight (float, optional ): rescaling weight for L0 regularization term.
79
87
Default: 1.0
80
88
81
- temperature (float): temperature of the concrete distribution, controls
82
- the degree of approximation, as 0 means the original Bernoulli
89
+ temperature (float, optional ): temperature of the concrete distribution,
90
+ controls the degree of approximation, as 0 means the original Bernoulli
83
91
without relaxation. The value should be between 0 and 1.
84
92
Default: 2/3
85
93
86
- lower_bound (float): the lower bound to "stretch" the binary concrete
87
- distribution
94
+ lower_bound (float, optional ): the lower bound to "stretch" the binary
95
+ concrete distribution
88
96
Default: -0.1
89
97
90
- upper_bound (float): the upper bound to "stretch" the binary concrete
91
- distribution
98
+ upper_bound (float, optional ): the upper bound to "stretch" the binary
99
+ concrete distribution
92
100
Default: 1.1
93
101
94
- eps (float): term to improve numerical stability in binary concerete
95
- sampling
102
+ eps (float, optional ): term to improve numerical stability in binary
103
+ concerete sampling
96
104
Default: 1e-8
97
105
98
- reg_reduction (str, optional): the reduction to apply to
99
- the regularization: 'none'| 'mean'| 'sum'. 'none': no reduction will be
100
- applied and it will be the same as the return of get_active_probs,
101
- 'mean': the sum of the gates non-zero probabilities will be divided by
102
- the number of gates, 'sum': the gates non-zero probabilities will
106
+ reg_reduction (str, optional): the reduction to apply to the regularization:
107
+ `` 'none'`` | `` 'mean'`` | `` 'sum'``. `` 'none'`` : no reduction will be
108
+ applied and it will be the same as the return of `` get_active_probs`` ,
109
+ `` 'mean'`` : the sum of the gates non-zero probabilities will be divided
110
+ by the number of gates, `` 'sum'`` : the gates non-zero probabilities will
103
111
be summed.
104
- Default: 'sum'
112
+ Default: `` 'sum'``
105
113
"""
106
114
super ().__init__ (
107
115
n_gates , mask = mask , reg_weight = reg_weight , reg_reduction = reg_reduction
@@ -193,7 +201,7 @@ def _from_pretrained(cls, log_alpha_param: Tensor, *args, **kwargs):
193
201
log_alpha_param (Tensor): FloatTensor containing weights for
194
202
the pretrained log_alpha
195
203
196
- mask (Optional[ Tensor] ): If provided, this allows grouping multiple
204
+ mask (Tensor, optional ): If provided, this allows grouping multiple
197
205
input tensor elements to share the same stochastic gate.
198
206
This tensor should be broadcastable to match the input shape
199
207
and contain integers in the range 0 to n_gates - 1.
@@ -202,26 +210,34 @@ def _from_pretrained(cls, log_alpha_param: Tensor, *args, **kwargs):
202
210
(on dimensions other than dim 0 - batch dim) is gated separately.
203
211
Default: None
204
212
205
- reg_weight (Optional[ float] ): rescaling weight for L0 regularization term.
213
+ reg_weight (float, optional ): rescaling weight for L0 regularization term.
206
214
Default: 1.0
207
215
208
- temperature (float): temperature of the concrete distribution, controls
209
- the degree of approximation, as 0 means the original Bernoulli
216
+ temperature (float, optional ): temperature of the concrete distribution,
217
+ controls the degree of approximation, as 0 means the original Bernoulli
210
218
without relaxation. The value should be between 0 and 1.
211
219
Default: 2/3
212
220
213
- lower_bound (float): the lower bound to "stretch" the binary concrete
214
- distribution
221
+ lower_bound (float, optional ): the lower bound to "stretch" the binary
222
+ concrete distribution
215
223
Default: -0.1
216
224
217
- upper_bound (float): the upper bound to "stretch" the binary concrete
218
- distribution
225
+ upper_bound (float, optional ): the upper bound to "stretch" the binary
226
+ concrete distribution
219
227
Default: 1.1
220
228
221
- eps (float): term to improve numerical stability in binary concerete
222
- sampling
229
+ eps (float, optional ): term to improve numerical stability in binary
230
+ concerete sampling
223
231
Default: 1e-8
224
232
233
+ reg_reduction (str, optional): the reduction to apply to the regularization:
234
+ ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be
235
+ applied and it will be the same as the return of ``get_active_probs``,
236
+ ``'mean'``: the sum of the gates non-zero probabilities will be divided
237
+ by the number of gates, ``'sum'``: the gates non-zero probabilities will
238
+ be summed.
239
+ Default: ``'sum'``
240
+
225
241
Returns:
226
242
stg (BinaryConcreteStochasticGates): StochasticGates instance
227
243
"""
0 commit comments