diff --git a/captum/_utils/common.py b/captum/_utils/common.py
index b62c71dc6..5b83b0be2 100644
--- a/captum/_utils/common.py
+++ b/captum/_utils/common.py
@@ -908,22 +908,6 @@ def _get_max_feature_index(feature_mask: Tuple[Tensor, ...]) -> int:
     return int(max(torch.max(mask).item() for mask in feature_mask if mask.numel()))
 
 
-def _get_feature_idx_to_tensor_idx(
-    formatted_feature_mask: Tuple[Tensor, ...],
-) -> Dict[int, List[int]]:
-    """
-    For a given tuple of tensors, return dict of tensor values to list of tensor indices
-    they appear in.
-    """
-    feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
-    for i, mask in enumerate(formatted_feature_mask):
-        for feature_idx in torch.unique(mask):
-            if feature_idx.item() not in feature_idx_to_tensor_idx:
-                feature_idx_to_tensor_idx[feature_idx.item()] = []
-            feature_idx_to_tensor_idx[feature_idx.item()].append(i)
-    return feature_idx_to_tensor_idx
-
-
 def _maybe_expand_parameters(
     perturbations_per_eval: int,
     formatted_inputs: Tuple[Tensor, ...],
diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py
index 746bf7cc4..06190af6e 100644
--- a/captum/attr/_core/feature_ablation.py
+++ b/captum/attr/_core/feature_ablation.py
@@ -24,7 +24,6 @@
     _format_additional_forward_args,
     _format_feature_mask,
     _format_output,
-    _get_feature_idx_to_tensor_idx,
     _is_tuple,
     _maybe_expand_parameters,
     _run_forward,
@@ -507,8 +506,8 @@ def _attribute_with_cross_tensor_feature_masks(
         perturbations_per_eval: int,
         **kwargs: Any,
     ) -> Tuple[List[Tensor], List[Tensor]]:
-        feature_idx_to_tensor_idx = _get_feature_idx_to_tensor_idx(
-            formatted_feature_mask
+        feature_idx_to_tensor_idx = self._get_feature_idx_to_tensor_idx(
+            formatted_feature_mask, **kwargs
         )
         all_feature_idxs = list(feature_idx_to_tensor_idx.keys())
 
@@ -575,6 +574,7 @@ def _attribute_with_cross_tensor_feature_masks(
                     current_feature_idxs,
                     feature_idx_to_tensor_idx,
                     current_num_ablated_features,
+                    **kwargs,
                 )
             )
 
@@ -613,6 +613,21 @@ def _attribute_with_cross_tensor_feature_masks(
         )
         return total_attrib, weights
 
+    def _get_feature_idx_to_tensor_idx(
+        self, formatted_feature_mask: Tuple[Tensor, ...], **kwargs: Any
+    ) -> Dict[int, List[int]]:
+        """
+        For a given tuple of mask tensors, return a dict mapping each feature index
+        (mask value) to the list of tensor indices in which it appears.
+        """
+        feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
+        for i, mask in enumerate(formatted_feature_mask):
+            for feature_idx in torch.unique(mask):
+                if feature_idx.item() not in feature_idx_to_tensor_idx:
+                    feature_idx_to_tensor_idx[feature_idx.item()] = []
+                feature_idx_to_tensor_idx[feature_idx.item()].append(i)
+        return feature_idx_to_tensor_idx
+
     def _should_skip_inputs_and_warn(
         self,
         current_feature_idxs: List[int],
@@ -656,6 +671,7 @@ def _construct_ablated_input_across_tensors(
         feature_idxs: List[int],
         feature_idx_to_tensor_idx: Dict[int, List[int]],
         current_num_ablated_features: int,
+        **kwargs: Any,
     ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
         ablated_inputs = []
         current_masks: List[Optional[Tensor]] = []
@@ -946,8 +962,8 @@ def _attribute_with_cross_tensor_feature_masks_future(
         perturbations_per_eval: int,
         **kwargs: Any,
     ) -> Future[Union[Tensor, Tuple[Tensor, ...]]]:
-        feature_idx_to_tensor_idx = _get_feature_idx_to_tensor_idx(
-            formatted_feature_mask
+        feature_idx_to_tensor_idx = self._get_feature_idx_to_tensor_idx(
+            formatted_feature_mask, **kwargs
         )
         all_feature_idxs = list(feature_idx_to_tensor_idx.keys())
 
@@ -1016,6 +1032,7 @@ def _attribute_with_cross_tensor_feature_masks_future(
                     current_feature_idxs,
                     feature_idx_to_tensor_idx,
                     current_num_ablated_features,
+                    **kwargs,
                 )
             )
 
diff --git a/captum/attr/_core/feature_permutation.py b/captum/attr/_core/feature_permutation.py
index efad78e02..b9630ab73 100644
--- a/captum/attr/_core/feature_permutation.py
+++ b/captum/attr/_core/feature_permutation.py
@@ -377,6 +377,7 @@ def _construct_ablated_input_across_tensors(
         feature_idxs: List[int],
         feature_idx_to_tensor_idx: Dict[int, List[int]],
         current_num_ablated_features: int,
+        **kwargs: Any,
     ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
         current_masks: List[Optional[Tensor]] = []
         tensor_idxs = {
diff --git a/captum/attr/_core/occlusion.py b/captum/attr/_core/occlusion.py
index 1b6d6e2bf..ccd752e64 100644
--- a/captum/attr/_core/occlusion.py
+++ b/captum/attr/_core/occlusion.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 # pyre-strict
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -267,6 +267,7 @@ def attribute(  # type: ignore
             shift_counts=tuple(shift_counts),
             strides=strides,
             show_progress=show_progress,
+            enable_cross_tensor_attribution=True,
         )
 
     def attribute_future(self) -> None:
@@ -310,6 +311,7 @@ def _construct_ablated_input(
                     kwargs["sliding_window_tensors"],
                     kwargs["strides"],
                     kwargs["shift_counts"],
+                    is_expanded_input=True,
                 )
                 for j in range(start_feature, end_feature)
             ],
@@ -327,11 +329,12 @@ def _construct_ablated_input(
 
     def _occlusion_mask(
         self,
-        expanded_input: Tensor,
+        input: Tensor,
         ablated_feature_num: int,
         sliding_window_tsr: Tensor,
         strides: Union[int, Tuple[int, ...]],
         shift_counts: Tuple[int, ...],
+        is_expanded_input: bool,
     ) -> Tensor:
         """
         This constructs the current occlusion mask, which is the appropriate
@@ -365,8 +368,9 @@ def _occlusion_mask(
             current_index.append((remaining_total % shift_count) * stride)
             remaining_total = remaining_total // shift_count
 
+        dim = 2 if is_expanded_input else 1
         remaining_padding = np.subtract(
-            expanded_input.shape[2:], np.add(current_index, sliding_window_tsr.shape)
+            input.shape[dim:], np.add(current_index, sliding_window_tsr.shape)
        )
         pad_values = [
             val for pair in zip(remaining_padding, current_index) for val in pair
@@ -391,3 +395,90 @@ def _get_feature_counts(
     ) -> Tuple[int, ...]:
         """return the numbers of possible input features"""
         return tuple(np.prod(counts).astype(int) for counts in kwargs["shift_counts"])
+
+    def _get_feature_idx_to_tensor_idx(
+        self, formatted_feature_mask: Tuple[Tensor, ...], **kwargs: Any
+    ) -> Dict[int, List[int]]:
+        feature_idx_to_tensor_idx = {}
+        curr_feature_idx = 0
+        for i, shift_count in enumerate(kwargs["shift_counts"]):
+            num_features = int(np.prod(shift_count))
+            for _ in range(num_features):
+                feature_idx_to_tensor_idx[curr_feature_idx] = [i]
+                curr_feature_idx += 1
+        return feature_idx_to_tensor_idx
+
+    def _get_accumulated_shift_count_products(
+        self,
+        shift_counts: Tuple[int, ...],
+    ) -> List[int]:
+        shift_count_prod = [np.prod(counts).astype(int) for counts in shift_counts]
+        curr_prod = 1
+        acc_prod = [0]
+        for i in range(1, len(shift_count_prod)):
+            curr_prod *= shift_count_prod[i - 1]
+            acc_prod.append(curr_prod)
+        return acc_prod
+
+    def _construct_ablated_input_across_tensors(
+        self,
+        inputs: Tuple[Tensor, ...],
+        input_mask: Tuple[Tensor, ...],
+        baselines: BaselineType,
+        feature_idxs: List[int],
+        feature_idx_to_tensor_idx: Dict[int, List[int]],
+        current_num_ablated_features: int,
+        **kwargs: Any,
+    ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
+        ablated_inputs = []
+        current_masks: List[Optional[Tensor]] = []
+        tensor_idxs = {
+            tensor_idx
+            for sublist in (
+                feature_idx_to_tensor_idx[feature_idx] for feature_idx in feature_idxs
+            )
+            for tensor_idx in sublist
+        }
+
+        accumulated_shift_count_prods = self._get_accumulated_shift_count_products(
+            kwargs["shift_counts"]
+        )
+        for i, input_tensor in enumerate(inputs):
+            if i not in tensor_idxs:
+                ablated_inputs.append(input_tensor)
+                current_masks.append(None)
+                continue
+            tensor_mask: List[Tensor] = []
+            baseline = baselines[i] if isinstance(baselines, tuple) else baselines
+            for feature_idx in feature_idxs:
+
+                if feature_idx_to_tensor_idx[feature_idx][0] != i:
+                    tensor_mask.append(
+                        torch.zeros((1,) + tuple(input_tensor.shape[1:]))
+                    )
+                    continue
+                ablated_feature_num = feature_idx - accumulated_shift_count_prods[i]
+                mask = self._occlusion_mask(
+                    input_tensor,
+                    ablated_feature_num,
+                    kwargs["sliding_window_tensors"][i],
+                    kwargs["strides"][i],
+                    kwargs["shift_counts"][i],
+                    is_expanded_input=False,
+                )
+                tensor_mask.append(mask)
+            assert baseline is not None, "baseline must be provided"
+            current_mask = torch.stack(tensor_mask, dim=0)
+            current_masks.append(current_mask)
+            ablated_input = input_tensor.clone().reshape(
+                (current_num_ablated_features, -1) + tuple(input_tensor.shape[1:])
+            )
+            ablated_input = (
+                ablated_input
+                * (
+                    torch.ones(1, dtype=torch.long, device=input_tensor.device)
+                    - current_mask
+                ).to(input_tensor.dtype)
+            ) + (baseline * current_mask.to(input_tensor.dtype))
+            ablated_inputs.append(ablated_input.reshape(input_tensor.shape))
+        return tuple(ablated_inputs), tuple(current_masks)
diff --git a/tests/attr/test_occlusion.py b/tests/attr/test_occlusion.py
index f4a884fdd..cd5d77887 100644
--- a/tests/attr/test_occlusion.py
+++ b/tests/attr/test_occlusion.py
@@ -282,6 +282,31 @@ def test_simple_multi_input_conv(self) -> None:
             strides=((1, 2, 1), (1, 1, 2)),
         )
 
+    def test_simple_multi_input_conv_diff_shapes(self) -> None:
+        net = BasicModel_ConvNet_One_Conv()
+        inp = torch.arange(16, dtype=torch.float).view(1, 1, 4, 4)
+        inp2 = torch.ones((1, 1, 4, 1))
+        self._occlusion_test_assert(
+            net,
+            (inp, inp2),
+            (
+                [
+                    [
+                        [
+                            [17.0, 17.0, 17.0, 17.0],
+                            [17.0, 17.0, 17.0, 17.0],
+                            [67.0, 67.0, 67.0, 67.0],
+                            [67.0, 67.0, 67.0, 67.0],
+                        ]
+                    ]
+                ],
+                [[[[10.0], [10.0], [8.0], [6.0]]]],
+            ),
+            perturbations_per_eval=(1, 2, 4, 8, 12, 16),
+            sliding_window_shapes=((1, 2, 4), (1, 2, 1)),
+            strides=((1, 2, 1), (1, 1, 1)),
+        )
+
     def test_futures_not_implemented(self) -> None:
        net = BasicModel_ConvNet_One_Conv()
        occ = Occlusion(net)
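
Usage sketch: with this change, Occlusion can build its perturbations across input tensors of
different shapes in a single pass, as exercised by the new test. The two-input model below is a
hypothetical stand-in (any multi-input nn.Module works) and is not taken from this patch; only
the Occlusion API calls are. Occlusion.attribute now opts into cross-tensor attribution itself
via enable_cross_tensor_attribution=True, so no extra argument is needed at the call site.

    import torch
    from captum.attr import Occlusion


    class TinyTwoInputModel(torch.nn.Module):
        # Hypothetical stand-in: sums each input to one scalar per example.
        def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
            return x1.sum(dim=(1, 2, 3)) + x2.sum(dim=(1, 2, 3))


    occlusion = Occlusion(TinyTwoInputModel())

    # Two inputs with different spatial shapes, mirroring the new test case.
    inp = torch.arange(16, dtype=torch.float).view(1, 1, 4, 4)
    inp2 = torch.ones((1, 1, 4, 1))

    # One sliding window shape and stride per input tensor.
    attr1, attr2 = occlusion.attribute(
        (inp, inp2),
        sliding_window_shapes=((1, 2, 4), (1, 2, 1)),
        strides=((1, 2, 1), (1, 1, 1)),
        baselines=0,
        perturbations_per_eval=4,
    )
    print(attr1.shape, attr2.shape)  # attributions match the input shapes

Internally, features are numbered consecutively across tensors: the first input above yields 2
features (shift counts (1, 2, 1)) and the second yields 3, so _get_accumulated_shift_count_products
returns the per-tensor offsets [0, 2] that map a global feature index back to the per-tensor
ablated_feature_num passed to _occlusion_mask.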