From a2bb059e091c1b52f1d19bb27f2fd0fa1ade06d1 Mon Sep 17 00:00:00 2001 From: Sarah Tran Date: Thu, 3 Jul 2025 17:23:50 -0700 Subject: [PATCH] Update occlusion to new method of constructing ablated batches (#1616) Summary: `FeaturePermutation` and `Occlusion` are subclasses of `FeatureAblation`, which contains the bulk of the logic around iterating through and processing perturbed inputs. Previously, `FeatureAblation` only constructed the perturbed input by looking at each input tensor individually, as there wasn't an explicit use-case that needed cross-tensor feature grouping. However this behavior has been modified to support cross-tensor masking, as different sparse features at Meta are represented by different tensors when they're finally passed to Captum, and we need to support grouping across features and feature types for various Ads workstreams/asks. The new behavior is mostly rolled out internally, and is controlled by the `enable_cross_tensor_attribution` flag. We do not want to support both behaviors indefinitely. `Occlusion` does not use the custom masks parameter exposed via `FeatureAblation`, but constructs the masks internally. There's no use-case requiring cross-tensor masking for occlusion, so it's not a requirement that it still follows the logic in `FeatureAblation`. However, it's been adapted in this diff in order to just reuse the code from `FeatureAblation`, as the logic specific to `enable_cross_tensor_attribution=False` will be going away. Reviewed By: cyrjano Differential Revision: D76483214 --- captum/_utils/common.py | 16 ---- captum/attr/_core/feature_ablation.py | 27 +++++-- captum/attr/_core/feature_permutation.py | 1 + captum/attr/_core/occlusion.py | 97 +++++++++++++++++++++++- tests/attr/test_occlusion.py | 25 ++++++ 5 files changed, 142 insertions(+), 24 deletions(-) diff --git a/captum/_utils/common.py b/captum/_utils/common.py index b62c71dc6..5b83b0be2 100644 --- a/captum/_utils/common.py +++ b/captum/_utils/common.py @@ -908,22 +908,6 @@ def _get_max_feature_index(feature_mask: Tuple[Tensor, ...]) -> int: return int(max(torch.max(mask).item() for mask in feature_mask if mask.numel())) -def _get_feature_idx_to_tensor_idx( - formatted_feature_mask: Tuple[Tensor, ...], -) -> Dict[int, List[int]]: - """ - For a given tuple of tensors, return dict of tensor values to list of tensor indices - they appear in. - """ - feature_idx_to_tensor_idx: Dict[int, List[int]] = {} - for i, mask in enumerate(formatted_feature_mask): - for feature_idx in torch.unique(mask): - if feature_idx.item() not in feature_idx_to_tensor_idx: - feature_idx_to_tensor_idx[feature_idx.item()] = [] - feature_idx_to_tensor_idx[feature_idx.item()].append(i) - return feature_idx_to_tensor_idx - - def _maybe_expand_parameters( perturbations_per_eval: int, formatted_inputs: Tuple[Tensor, ...], diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py index 746bf7cc4..06190af6e 100644 --- a/captum/attr/_core/feature_ablation.py +++ b/captum/attr/_core/feature_ablation.py @@ -24,7 +24,6 @@ _format_additional_forward_args, _format_feature_mask, _format_output, - _get_feature_idx_to_tensor_idx, _is_tuple, _maybe_expand_parameters, _run_forward, @@ -507,8 +506,8 @@ def _attribute_with_cross_tensor_feature_masks( perturbations_per_eval: int, **kwargs: Any, ) -> Tuple[List[Tensor], List[Tensor]]: - feature_idx_to_tensor_idx = _get_feature_idx_to_tensor_idx( - formatted_feature_mask + feature_idx_to_tensor_idx = self._get_feature_idx_to_tensor_idx( + formatted_feature_mask, **kwargs ) all_feature_idxs = list(feature_idx_to_tensor_idx.keys()) @@ -575,6 +574,7 @@ def _attribute_with_cross_tensor_feature_masks( current_feature_idxs, feature_idx_to_tensor_idx, current_num_ablated_features, + **kwargs, ) ) @@ -613,6 +613,21 @@ def _attribute_with_cross_tensor_feature_masks( ) return total_attrib, weights + def _get_feature_idx_to_tensor_idx( + self, formatted_feature_mask: Tuple[Tensor, ...], **kwargs: Any + ) -> Dict[int, List[int]]: + """ + For a given tuple of tensors, return dict of tensor values to list of tensor + indices they appear in. + """ + feature_idx_to_tensor_idx: Dict[int, List[int]] = {} + for i, mask in enumerate(formatted_feature_mask): + for feature_idx in torch.unique(mask): + if feature_idx.item() not in feature_idx_to_tensor_idx: + feature_idx_to_tensor_idx[feature_idx.item()] = [] + feature_idx_to_tensor_idx[feature_idx.item()].append(i) + return feature_idx_to_tensor_idx + def _should_skip_inputs_and_warn( self, current_feature_idxs: List[int], @@ -656,6 +671,7 @@ def _construct_ablated_input_across_tensors( feature_idxs: List[int], feature_idx_to_tensor_idx: Dict[int, List[int]], current_num_ablated_features: int, + **kwargs: Any, ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]: ablated_inputs = [] current_masks: List[Optional[Tensor]] = [] @@ -946,8 +962,8 @@ def _attribute_with_cross_tensor_feature_masks_future( perturbations_per_eval: int, **kwargs: Any, ) -> Future[Union[Tensor, Tuple[Tensor, ...]]]: - feature_idx_to_tensor_idx = _get_feature_idx_to_tensor_idx( - formatted_feature_mask + feature_idx_to_tensor_idx = self._get_feature_idx_to_tensor_idx( + formatted_feature_mask, **kwargs ) all_feature_idxs = list(feature_idx_to_tensor_idx.keys()) @@ -1016,6 +1032,7 @@ def _attribute_with_cross_tensor_feature_masks_future( current_feature_idxs, feature_idx_to_tensor_idx, current_num_ablated_features, + **kwargs, ) ) diff --git a/captum/attr/_core/feature_permutation.py b/captum/attr/_core/feature_permutation.py index efad78e02..b9630ab73 100644 --- a/captum/attr/_core/feature_permutation.py +++ b/captum/attr/_core/feature_permutation.py @@ -377,6 +377,7 @@ def _construct_ablated_input_across_tensors( feature_idxs: List[int], feature_idx_to_tensor_idx: Dict[int, List[int]], current_num_ablated_features: int, + **kwargs: Any, ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]: current_masks: List[Optional[Tensor]] = [] tensor_idxs = { diff --git a/captum/attr/_core/occlusion.py b/captum/attr/_core/occlusion.py index 1b6d6e2bf..ccd752e64 100644 --- a/captum/attr/_core/occlusion.py +++ b/captum/attr/_core/occlusion.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # pyre-strict -from typing import Any, Callable, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -267,6 +267,7 @@ def attribute( # type: ignore shift_counts=tuple(shift_counts), strides=strides, show_progress=show_progress, + enable_cross_tensor_attribution=True, ) def attribute_future(self) -> None: @@ -310,6 +311,7 @@ def _construct_ablated_input( kwargs["sliding_window_tensors"], kwargs["strides"], kwargs["shift_counts"], + is_expanded_input=True, ) for j in range(start_feature, end_feature) ], @@ -327,11 +329,12 @@ def _construct_ablated_input( def _occlusion_mask( self, - expanded_input: Tensor, + input: Tensor, ablated_feature_num: int, sliding_window_tsr: Tensor, strides: Union[int, Tuple[int, ...]], shift_counts: Tuple[int, ...], + is_expanded_input: bool, ) -> Tensor: """ This constructs the current occlusion mask, which is the appropriate @@ -365,8 +368,9 @@ def _occlusion_mask( current_index.append((remaining_total % shift_count) * stride) remaining_total = remaining_total // shift_count + dim = 2 if is_expanded_input else 1 remaining_padding = np.subtract( - expanded_input.shape[2:], np.add(current_index, sliding_window_tsr.shape) + input.shape[dim:], np.add(current_index, sliding_window_tsr.shape) ) pad_values = [ val for pair in zip(remaining_padding, current_index) for val in pair @@ -391,3 +395,90 @@ def _get_feature_counts( ) -> Tuple[int, ...]: """return the numbers of possible input features""" return tuple(np.prod(counts).astype(int) for counts in kwargs["shift_counts"]) + + def _get_feature_idx_to_tensor_idx( + self, formatted_feature_mask: Tuple[Tensor, ...], **kwargs: Any + ) -> Dict[int, List[int]]: + feature_idx_to_tensor_idx = {} + curr_feature_idx = 0 + for i, shift_count in enumerate(kwargs["shift_counts"]): + num_features = int(np.prod(shift_count)) + for _ in range(num_features): + feature_idx_to_tensor_idx[curr_feature_idx] = [i] + curr_feature_idx += 1 + return feature_idx_to_tensor_idx + + def _get_accumulated_shift_count_products( + self, + shift_counts: Tuple[int, ...], + ) -> List[int]: + shift_count_prod = [np.prod(counts).astype(int) for counts in shift_counts] + curr_prod = 1 + acc_prod = [0] + for i in range(1, len(shift_count_prod)): + curr_prod *= shift_count_prod[i - 1] + acc_prod.append(curr_prod) + return acc_prod + + def _construct_ablated_input_across_tensors( + self, + inputs: Tuple[Tensor, ...], + input_mask: Tuple[Tensor, ...], + baselines: BaselineType, + feature_idxs: List[int], + feature_idx_to_tensor_idx: Dict[int, List[int]], + current_num_ablated_features: int, + **kwargs: Any, + ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]: + ablated_inputs = [] + current_masks: List[Optional[Tensor]] = [] + tensor_idxs = { + tensor_idx + for sublist in ( + feature_idx_to_tensor_idx[feature_idx] for feature_idx in feature_idxs + ) + for tensor_idx in sublist + } + + accumulated_shift_count_prods = self._get_accumulated_shift_count_products( + kwargs["shift_counts"] + ) + for i, input_tensor in enumerate(inputs): + if i not in tensor_idxs: + ablated_inputs.append(input_tensor) + current_masks.append(None) + continue + tensor_mask: List[Tensor] = [] + baseline = baselines[i] if isinstance(baselines, tuple) else baselines + for feature_idx in feature_idxs: + + if feature_idx_to_tensor_idx[feature_idx][0] != i: + tensor_mask.append( + torch.zeros((1,) + tuple(input_tensor.shape[1:])) + ) + continue + ablated_feature_num = feature_idx - accumulated_shift_count_prods[i] + mask = self._occlusion_mask( + input_tensor, + ablated_feature_num, + kwargs["sliding_window_tensors"][i], + kwargs["strides"][i], + kwargs["shift_counts"][i], + is_expanded_input=False, + ) + tensor_mask.append(mask) + assert baseline is not None, "baseline must be provided" + current_mask = torch.stack(tensor_mask, dim=0) + current_masks.append(current_mask) + ablated_input = input_tensor.clone().reshape( + (current_num_ablated_features, -1) + tuple(input_tensor.shape[1:]) + ) + ablated_input = ( + ablated_input + * ( + torch.ones(1, dtype=torch.long, device=input_tensor.device) + - current_mask + ).to(input_tensor.dtype) + ) + (baseline * current_mask.to(input_tensor.dtype)) + ablated_inputs.append(ablated_input.reshape(input_tensor.shape)) + return tuple(ablated_inputs), tuple(current_masks) diff --git a/tests/attr/test_occlusion.py b/tests/attr/test_occlusion.py index f4a884fdd..cd5d77887 100644 --- a/tests/attr/test_occlusion.py +++ b/tests/attr/test_occlusion.py @@ -282,6 +282,31 @@ def test_simple_multi_input_conv(self) -> None: strides=((1, 2, 1), (1, 1, 2)), ) + def test_simple_multi_input_conv_diff_shapes(self) -> None: + net = BasicModel_ConvNet_One_Conv() + inp = torch.arange(16, dtype=torch.float).view(1, 1, 4, 4) + inp2 = torch.ones((1, 1, 4, 1)) + self._occlusion_test_assert( + net, + (inp, inp2), + ( + [ + [ + [ + [17.0, 17.0, 17.0, 17.0], + [17.0, 17.0, 17.0, 17.0], + [67.0, 67.0, 67.0, 67.0], + [67.0, 67.0, 67.0, 67.0], + ] + ] + ], + [[[[10.0], [10.0], [8.0], [6.0]]]], + ), + perturbations_per_eval=(1, 2, 4, 8, 12, 16), + sliding_window_shapes=((1, 2, 4), (1, 2, 1)), + strides=((1, 2, 1), (1, 1, 1)), + ) + def test_futures_not_implemented(self) -> None: net = BasicModel_ConvNet_One_Conv() occ = Occlusion(net)