diff --git a/captum/module/stochastic_gates_base.py b/captum/module/stochastic_gates_base.py index b34a4d5f4..2d72b732a 100644 --- a/captum/module/stochastic_gates_base.py +++ b/captum/module/stochastic_gates_base.py @@ -78,9 +78,7 @@ def __init__( self.reg_reduction = reg_reduction self.n_gates = n_gates - self.register_buffer( - "mask", mask.detach().clone() if mask is not None else None - ) + self.mask = mask self.reg_weight = reg_weight def forward(self, input_tensor: Tensor) -> Tuple[Tensor, Tensor]: