Update occlusion to new method of constructing ablated batches #1616
Closed · wants to merge 1 commit
captum/_utils/common.py (16 changes: 0 additions & 16 deletions)
@@ -908,22 +908,6 @@ def _get_max_feature_index(feature_mask: Tuple[Tensor, ...]) -> int:
return int(max(torch.max(mask).item() for mask in feature_mask if mask.numel()))


def _get_feature_idx_to_tensor_idx(
formatted_feature_mask: Tuple[Tensor, ...],
) -> Dict[int, List[int]]:
"""
For a given tuple of tensors, return dict of tensor values to list of tensor indices
they appear in.
"""
feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
for i, mask in enumerate(formatted_feature_mask):
for feature_idx in torch.unique(mask):
if feature_idx.item() not in feature_idx_to_tensor_idx:
feature_idx_to_tensor_idx[feature_idx.item()] = []
feature_idx_to_tensor_idx[feature_idx.item()].append(i)
return feature_idx_to_tensor_idx


def _maybe_expand_parameters(
perturbations_per_eval: int,
formatted_inputs: Tuple[Tensor, ...],
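For context: the helper deleted above maps each feature id appearing in a feature mask to the list of input tensors it occurs in; it is re-added below as an overridable method on FeatureAblation. A minimal sketch of the same mapping, using made-up masks:

    import torch

    # Hypothetical masks: two input tensors, three feature ids, none shared.
    mask_a = torch.tensor([[0, 0, 1, 1]])  # feature ids 0 and 1 live in tensor 0
    mask_b = torch.tensor([[2, 2]])        # feature id 2 lives in tensor 1

    feature_idx_to_tensor_idx = {}
    for i, mask in enumerate((mask_a, mask_b)):
        for feature_idx in torch.unique(mask):
            feature_idx_to_tensor_idx.setdefault(int(feature_idx), []).append(i)

    print(feature_idx_to_tensor_idx)  # {0: [0], 1: [0], 2: [1]}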
captum/attr/_core/feature_ablation.py (27 changes: 22 additions & 5 deletions)
@@ -24,7 +24,6 @@
_format_additional_forward_args,
_format_feature_mask,
_format_output,
_get_feature_idx_to_tensor_idx,
_is_tuple,
_maybe_expand_parameters,
_run_forward,
@@ -507,8 +506,8 @@ def _attribute_with_cross_tensor_feature_masks(
perturbations_per_eval: int,
**kwargs: Any,
) -> Tuple[List[Tensor], List[Tensor]]:
feature_idx_to_tensor_idx = _get_feature_idx_to_tensor_idx(
formatted_feature_mask
feature_idx_to_tensor_idx = self._get_feature_idx_to_tensor_idx(
formatted_feature_mask, **kwargs
)
all_feature_idxs = list(feature_idx_to_tensor_idx.keys())

@@ -575,6 +574,7 @@ def _attribute_with_cross_tensor_feature_masks(
current_feature_idxs,
feature_idx_to_tensor_idx,
current_num_ablated_features,
**kwargs,
)
)

@@ -613,6 +613,21 @@ def _attribute_with_cross_tensor_feature_masks(
)
return total_attrib, weights

def _get_feature_idx_to_tensor_idx(
self, formatted_feature_mask: Tuple[Tensor, ...], **kwargs: Any
) -> Dict[int, List[int]]:
"""
For a given tuple of tensors, return dict of tensor values to list of tensor
indices they appear in.
"""
feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
for i, mask in enumerate(formatted_feature_mask):
for feature_idx in torch.unique(mask):
if feature_idx.item() not in feature_idx_to_tensor_idx:
feature_idx_to_tensor_idx[feature_idx.item()] = []
feature_idx_to_tensor_idx[feature_idx.item()].append(i)
return feature_idx_to_tensor_idx

def _should_skip_inputs_and_warn(
self,
current_feature_idxs: List[int],
@@ -656,6 +671,7 @@ def _construct_ablated_input_across_tensors(
feature_idxs: List[int],
feature_idx_to_tensor_idx: Dict[int, List[int]],
current_num_ablated_features: int,
**kwargs: Any,
) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
ablated_inputs = []
current_masks: List[Optional[Tensor]] = []
@@ -946,8 +962,8 @@ def _attribute_with_cross_tensor_feature_masks_future(
perturbations_per_eval: int,
**kwargs: Any,
) -> Future[Union[Tensor, Tuple[Tensor, ...]]]:
feature_idx_to_tensor_idx = _get_feature_idx_to_tensor_idx(
formatted_feature_mask
feature_idx_to_tensor_idx = self._get_feature_idx_to_tensor_idx(
formatted_feature_mask, **kwargs
)
all_feature_idxs = list(feature_idx_to_tensor_idx.keys())

@@ -1016,6 +1032,7 @@ def _attribute_with_cross_tensor_feature_masks_future(
current_feature_idxs,
feature_idx_to_tensor_idx,
current_num_ablated_features,
**kwargs,
)
)

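The substantive change in this file is that the feature-id-to-tensor mapping becomes an instance method and that **kwargs is threaded through to both _get_feature_idx_to_tensor_idx and _construct_ablated_input_across_tensors, so a subclass can build the mapping and the ablated batches from perturbation-specific arguments instead of from the feature mask. A hypothetical subclass sketch of that hook (names here are illustrative; Occlusion's real override appears further down):

    from typing import Any, Dict, List, Tuple

    import numpy as np
    from torch import Tensor

    class WindowedAblation:  # illustrative stand-in for a FeatureAblation subclass
        def _get_feature_idx_to_tensor_idx(
            self, formatted_feature_mask: Tuple[Tensor, ...], **kwargs: Any
        ) -> Dict[int, List[int]]:
            # Ignore the mask and assign each feature id to exactly one tensor,
            # based on how many occlusion windows fit in each input.
            mapping: Dict[int, List[int]] = {}
            feature_idx = 0
            for tensor_idx, counts in enumerate(kwargs["shift_counts"]):
                for _ in range(int(np.prod(counts))):
                    mapping[feature_idx] = [tensor_idx]
                    feature_idx += 1
            return mapping

Calling it with shift_counts=((2, 1), (4,)), for example, assigns ids 0-1 to the first tensor and ids 2-5 to the second.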
captum/attr/_core/feature_permutation.py (1 change: 1 addition & 0 deletions)
@@ -377,6 +377,7 @@ def _construct_ablated_input_across_tensors(
feature_idxs: List[int],
feature_idx_to_tensor_idx: Dict[int, List[int]],
current_num_ablated_features: int,
**kwargs: Any,
) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
current_masks: List[Optional[Tensor]] = []
tensor_idxs = {
captum/attr/_core/occlusion.py (97 changes: 94 additions & 3 deletions)
@@ -1,7 +1,7 @@
#!/usr/bin/env python3

# pyre-strict
from typing import Any, Callable, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
@@ -267,6 +267,7 @@ def attribute( # type: ignore
shift_counts=tuple(shift_counts),
strides=strides,
show_progress=show_progress,
enable_cross_tensor_attribution=True,
)

def attribute_future(self) -> None:
@@ -310,6 +311,7 @@ def _construct_ablated_input(
kwargs["sliding_window_tensors"],
kwargs["strides"],
kwargs["shift_counts"],
is_expanded_input=True,
)
for j in range(start_feature, end_feature)
],
@@ -327,11 +329,12 @@

def _occlusion_mask(
self,
expanded_input: Tensor,
input: Tensor,
ablated_feature_num: int,
sliding_window_tsr: Tensor,
strides: Union[int, Tuple[int, ...]],
shift_counts: Tuple[int, ...],
is_expanded_input: bool,
) -> Tensor:
"""
This constructs the current occlusion mask, which is the appropriate
@@ -365,8 +368,9 @@
current_index.append((remaining_total % shift_count) * stride)
remaining_total = remaining_total // shift_count

dim = 2 if is_expanded_input else 1
remaining_padding = np.subtract(
expanded_input.shape[2:], np.add(current_index, sliding_window_tsr.shape)
input.shape[dim:], np.add(current_index, sliding_window_tsr.shape)
)
pad_values = [
val for pair in zip(remaining_padding, current_index) for val in pair
@@ -391,3 +395,90 @@ def _get_feature_counts(
) -> Tuple[int, ...]:
"""return the numbers of possible input features"""
return tuple(np.prod(counts).astype(int) for counts in kwargs["shift_counts"])

def _get_feature_idx_to_tensor_idx(
self, formatted_feature_mask: Tuple[Tensor, ...], **kwargs: Any
) -> Dict[int, List[int]]:
feature_idx_to_tensor_idx = {}
curr_feature_idx = 0
for i, shift_count in enumerate(kwargs["shift_counts"]):
num_features = int(np.prod(shift_count))
for _ in range(num_features):
feature_idx_to_tensor_idx[curr_feature_idx] = [i]
curr_feature_idx += 1
return feature_idx_to_tensor_idx

def _get_accumulated_shift_count_products(
self,
shift_counts: Tuple[int, ...],
) -> List[int]:
shift_count_prod = [np.prod(counts).astype(int) for counts in shift_counts]
curr_prod = 1
acc_prod = [0]
for i in range(1, len(shift_count_prod)):
curr_prod *= shift_count_prod[i - 1]
acc_prod.append(curr_prod)
return acc_prod

def _construct_ablated_input_across_tensors(
self,
inputs: Tuple[Tensor, ...],
input_mask: Tuple[Tensor, ...],
baselines: BaselineType,
feature_idxs: List[int],
feature_idx_to_tensor_idx: Dict[int, List[int]],
current_num_ablated_features: int,
**kwargs: Any,
) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
ablated_inputs = []
current_masks: List[Optional[Tensor]] = []
tensor_idxs = {
tensor_idx
for sublist in (
feature_idx_to_tensor_idx[feature_idx] for feature_idx in feature_idxs
)
for tensor_idx in sublist
}

accumulated_shift_count_prods = self._get_accumulated_shift_count_products(
kwargs["shift_counts"]
)
for i, input_tensor in enumerate(inputs):
if i not in tensor_idxs:
ablated_inputs.append(input_tensor)
current_masks.append(None)
continue
tensor_mask: List[Tensor] = []
baseline = baselines[i] if isinstance(baselines, tuple) else baselines
for feature_idx in feature_idxs:

if feature_idx_to_tensor_idx[feature_idx][0] != i:
tensor_mask.append(
torch.zeros((1,) + tuple(input_tensor.shape[1:]))
)
continue
ablated_feature_num = feature_idx - accumulated_shift_count_prods[i]
mask = self._occlusion_mask(
input_tensor,
ablated_feature_num,
kwargs["sliding_window_tensors"][i],
kwargs["strides"][i],
kwargs["shift_counts"][i],
is_expanded_input=False,
)
tensor_mask.append(mask)
assert baseline is not None, "baseline must be provided"
current_mask = torch.stack(tensor_mask, dim=0)
current_masks.append(current_mask)
ablated_input = input_tensor.clone().reshape(
(current_num_ablated_features, -1) + tuple(input_tensor.shape[1:])
)
ablated_input = (
ablated_input
* (
torch.ones(1, dtype=torch.long, device=input_tensor.device)
- current_mask
).to(input_tensor.dtype)
) + (baseline * current_mask.to(input_tensor.dtype))
ablated_inputs.append(ablated_input.reshape(input_tensor.shape))
return tuple(ablated_inputs), tuple(current_masks)
tests/attr/test_occlusion.py (25 changes: 25 additions & 0 deletions)
@@ -282,6 +282,31 @@ def test_simple_multi_input_conv(self) -> None:
strides=((1, 2, 1), (1, 1, 2)),
)

def test_simple_multi_input_conv_diff_shapes(self) -> None:
net = BasicModel_ConvNet_One_Conv()
inp = torch.arange(16, dtype=torch.float).view(1, 1, 4, 4)
inp2 = torch.ones((1, 1, 4, 1))
self._occlusion_test_assert(
net,
(inp, inp2),
(
[
[
[
[17.0, 17.0, 17.0, 17.0],
[17.0, 17.0, 17.0, 17.0],
[67.0, 67.0, 67.0, 67.0],
[67.0, 67.0, 67.0, 67.0],
]
]
],
[[[[10.0], [10.0], [8.0], [6.0]]]],
),
perturbations_per_eval=(1, 2, 4, 8, 12, 16),
sliding_window_shapes=((1, 2, 4), (1, 2, 1)),
strides=((1, 2, 1), (1, 1, 1)),
)

def test_futures_not_implemented(self) -> None:
net = BasicModel_ConvNet_One_Conv()
occ = Occlusion(net)
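For reference, the new test exercises a call of roughly the following shape on two inputs with different spatial dimensions. The model below is a made-up stand-in (the test uses BasicModel_ConvNet_One_Conv); the inputs, window shapes, and strides are copied from the test, and target=0 is assumed here because the stand-in model is multi-output:

    import torch
    import torch.nn as nn
    from captum.attr import Occlusion

    class TwoInputNet(nn.Module):  # stand-in: accepts two differently shaped inputs
        def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
            return torch.stack([x1.sum(dim=(1, 2, 3)), x2.sum(dim=(1, 2, 3))], dim=1)

    occ = Occlusion(TwoInputNet())
    inp = torch.arange(16, dtype=torch.float).view(1, 1, 4, 4)
    inp2 = torch.ones((1, 1, 4, 1))

    attr1, attr2 = occ.attribute(
        (inp, inp2),
        sliding_window_shapes=((1, 2, 4), (1, 2, 1)),
        strides=((1, 2, 1), (1, 1, 1)),
        target=0,
        perturbations_per_eval=2,
    )
    print(attr1.shape, attr2.shape)  # torch.Size([1, 1, 4, 4]) torch.Size([1, 1, 4, 1])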