Skip to content

Commit ecc81e6

Browse files
aobo-yfacebook-github-bot
authored andcommitted
update stg sphinx to include inherited methods (#1095)
Summary: Include inherited methods in sphinx and remove dummy wrapper in children ![Screen Shot 2022-12-16 at 6 46 00 PM](https://user-images.githubusercontent.com/5113450/208220030-50d853a9-bea8-4a4d-bdfc-9673c18fc987.png) Pull Request resolved: #1095 Reviewed By: NarineK Differential Revision: D42117580 Pulled By: aobo-y fbshipit-source-id: 6eace5ff620c5765bff3df511d95a01a9652ff87
1 parent fe13596 commit ecc81e6

File tree

5 files changed

+17
-41
lines changed

5 files changed

+17
-41
lines changed

captum/module/binary_concrete_stochastic_gates.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -133,22 +133,6 @@ def __init__(
133133
# pre-calculate the fixed term used in active prob
134134
self.active_prob_offset = temperature * math.log(-lower_bound / upper_bound)
135135

136-
def forward(self, *args, **kwargs):
137-
"""
138-
Args:
139-
input_tensor (Tensor): Tensor to be gated with stochastic gates
140-
141-
142-
Outputs:
143-
gated_input (Tensor): Tensor of the same shape weighted by the sampled
144-
gate values
145-
146-
l0_reg (Tensor): L0 regularization term to be optimized together with
147-
model loss,
148-
e.g. loss(model_out, target) + l0_reg
149-
"""
150-
return super().forward(*args, **kwargs)
151-
152136
def _sample_gate_values(self, batch_size: int) -> Tensor:
153137
"""
154138
Sample gate values for each example in the batch from the binary concrete

captum/module/gaussian_stochastic_gates.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -78,21 +78,6 @@ def __init__(
7878
assert 0 < std, f"the standard deviation should be positive, received {std}"
7979
self.std = std
8080

81-
def forward(self, *args, **kwargs):
82-
"""
83-
Args:
84-
input_tensor (Tensor): Tensor to be gated with stochastic gates
85-
86-
Outputs:
87-
gated_input (Tensor): Tensor of the same shape weighted by the sampled
88-
gate values
89-
90-
l0_reg (Tensor): L0 regularization term to be optimized together with
91-
model loss,
92-
e.g. loss(model_out, target) + l0_reg
93-
"""
94-
return super().forward(*args, **kwargs)
95-
9681
def _sample_gate_values(self, batch_size: int) -> Tensor:
9782
"""
9883
Sample gate values for each example in the batch from the Gaussian distribution

captum/module/stochastic_gates_base.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,13 @@ def forward(self, input_tensor: Tensor) -> Tuple[Tensor, Tensor]:
8787
input_tensor (Tensor): Tensor to be gated with stochastic gates
8888
8989
90-
Outputs:
91-
gated_input (Tensor): Tensor of the same shape weighted by the sampled
90+
Returns:
91+
tuple[Tensor, Tensor]:
92+
93+
- gated_input (Tensor): Tensor of the same shape weighted by the sampled
9294
gate values
9395
94-
l0_reg (Tensor): L0 regularization term to be optimized together with
96+
- l0_reg (Tensor): L0 regularization term to be optimized together with
9597
model loss,
9698
e.g. loss(model_out, target) + l0_reg
9799
"""
@@ -140,16 +142,18 @@ def get_gate_values(self, clamp: bool = True) -> Tensor:
140142
Get the gate values, which are the means of the underneath gate distributions,
141143
optionally clamped within 0 and 1.
142144
143-
Returns:
144-
gate_values (Tensor): value of each gate in shape(n_gates)
145-
146-
clamp (bool): if clamp the gate values. As smoothed Bernoulli
147-
variables, gate values are clamped withn 0 and 1 by defautl.
145+
Args:
146+
clamp (bool): whether to clamp the gate values or not. As smoothed Bernoulli
147+
variables, gate values are clamped within 0 and 1 by default.
148148
Turn this off to get the raw means of the underneath
149-
distribution (e.g., conrete, gaussian), which can be useful to
149+
distribution (e.g., concrete, gaussian), which can be useful to
150150
differentiate the gates' importance when multiple gate
151151
values are beyond 0 or 1.
152152
Default: True
153+
154+
Returns:
155+
Tensor:
156+
- gate_values (Tensor): value of each gate in shape(n_gates)
153157
"""
154158
gate_values = self._get_gate_values()
155159
if clamp:
@@ -162,7 +166,8 @@ def get_gate_active_probs(self) -> Tensor:
162166
Get the active probability of each gate, i.e, gate value > 0
163167
164168
Returns:
165-
probs (Tensor): probabilities tensor of the gates are active
169+
Tensor:
170+
- probs (Tensor): probabilities tensor of the gates are active
166171
in shape(n_gates)
167172
"""
168173
return self._get_gate_active_probs().detach()

sphinx/source/binary_concrete_stg.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ BinaryConcreteStochasticGates
33

44
.. autoclass:: captum.module.BinaryConcreteStochasticGates
55
:members:
6+
:inherited-members: Module

sphinx/source/gaussian_stg.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ GaussianStochasticGates
33

44
.. autoclass:: captum.module.GaussianStochasticGates
55
:members:
6+
:inherited-members: Module

0 commit comments

Comments
 (0)