From 1850394ab71080efb680963b7ddc986e0cda3bf6 Mon Sep 17 00:00:00 2001 From: Sarah Tran Date: Thu, 12 Jun 2025 18:00:48 -0700 Subject: [PATCH] tmp Summary: do not commit Rollback Plan: Differential Revision: D76560862 --- captum/module/stochastic_gates_base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/captum/module/stochastic_gates_base.py b/captum/module/stochastic_gates_base.py index b34a4d5f4d..2d72b732a2 100644 --- a/captum/module/stochastic_gates_base.py +++ b/captum/module/stochastic_gates_base.py @@ -78,9 +78,7 @@ def __init__( self.reg_reduction = reg_reduction self.n_gates = n_gates - self.register_buffer( - "mask", mask.detach().clone() if mask is not None else None - ) + self.mask = mask self.reg_weight = reg_weight def forward(self, input_tensor: Tensor) -> Tuple[Tensor, Tensor]: