Commit 0fea20c

sarahtranfb authored and facebook-github-bot committed
Update occlusion to new method of constructing ablated batches (#1616)
Summary: `FeaturePermutation` and `Occlusion` are subclasses of `FeatureAblation`, which contains the bulk of the logic for iterating through and processing perturbed inputs. Previously, `FeatureAblation` constructed each perturbed input by looking at each input tensor individually, since there was no explicit use case for cross-tensor feature grouping. That behavior has since been extended to support cross-tensor masking: different sparse features at Meta are represented by different tensors by the time they reach Captum, and several Ads workstreams need grouping across features and feature types. The new behavior is mostly rolled out internally and is controlled by the `enable_cross_tensor_attribution` flag; we do not want to support both behaviors indefinitely. `Occlusion` does not use the custom feature-mask parameter exposed via `FeatureAblation`; it constructs its masks internally. No use case requires cross-tensor masking for occlusion, so it is not a requirement that `Occlusion` follow the `FeatureAblation` logic. This diff adapts it anyway, purely to reuse the `FeatureAblation` code path, since the logic specific to `enable_cross_tensor_attribution=False` is going away.

Reviewed By: cyrjano

Differential Revision: D76483214
1 parent 4be90f2 commit 0fea20c
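
To make the flag concrete: below is a minimal sketch of a cross-tensor feature mask with `FeatureAblation`. The model, inputs, and mask values are hypothetical; the flag is the `enable_cross_tensor_attribution` kwarg named in the summary and passed through in this diff.

import torch
from captum.attr import FeatureAblation

def forward(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    # Toy model: one scalar output per example.
    return x1.sum(dim=1) + x2.sum(dim=1)

fa = FeatureAblation(forward)
x1 = torch.ones(4, 3)
x2 = torch.ones(4, 2)
# Feature id 0 appears in both masks, so with cross-tensor attribution the
# corresponding entries of x1 and x2 are ablated together in one perturbation.
mask1 = torch.tensor([[0, 0, 1]])
mask2 = torch.tensor([[0, 2]])
attr1, attr2 = fa.attribute(
    (x1, x2),
    feature_mask=(mask1, mask2),
    enable_cross_tensor_attribution=True,
)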

File tree

5 files changed: +152, −24 lines changed


captum/_utils/common.py

Lines changed: 0 additions & 16 deletions
@@ -908,22 +908,6 @@ def _get_max_feature_index(feature_mask: Tuple[Tensor, ...]) -> int:
     return int(max(torch.max(mask).item() for mask in feature_mask if mask.numel()))
 
 
-def _get_feature_idx_to_tensor_idx(
-    formatted_feature_mask: Tuple[Tensor, ...],
-) -> Dict[int, List[int]]:
-    """
-    For a given tuple of tensors, return dict of tensor values to list of tensor indices
-    they appear in.
-    """
-    feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
-    for i, mask in enumerate(formatted_feature_mask):
-        for feature_idx in torch.unique(mask):
-            if feature_idx.item() not in feature_idx_to_tensor_idx:
-                feature_idx_to_tensor_idx[feature_idx.item()] = []
-            feature_idx_to_tensor_idx[feature_idx.item()].append(i)
-    return feature_idx_to_tensor_idx
-
-
 def _maybe_expand_parameters(
     perturbations_per_eval: int,
     formatted_inputs: Tuple[Tensor, ...],
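
This helper is not gone for good: it reappears as a `FeatureAblation` method in the next file so that subclasses such as `Occlusion` can override it. For reference, here is a standalone sketch of the mapping it computes, using hypothetical masks:

import torch

# Feature id 0 appears in both masks (a cross-tensor feature group);
# ids 1 and 2 each live in a single tensor.
masks = (torch.tensor([[0, 1]]), torch.tensor([[0, 2]]))

feature_idx_to_tensor_idx = {}
for i, mask in enumerate(masks):
    for feature_idx in torch.unique(mask):
        feature_idx_to_tensor_idx.setdefault(int(feature_idx.item()), []).append(i)

print(feature_idx_to_tensor_idx)  # {0: [0, 1], 1: [0], 2: [1]}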

captum/attr/_core/feature_ablation.py

Lines changed: 22 additions & 5 deletions
@@ -24,7 +24,6 @@
     _format_additional_forward_args,
     _format_feature_mask,
     _format_output,
-    _get_feature_idx_to_tensor_idx,
     _is_tuple,
     _maybe_expand_parameters,
     _run_forward,
@@ -507,8 +506,8 @@ def _attribute_with_cross_tensor_feature_masks(
         perturbations_per_eval: int,
         **kwargs: Any,
     ) -> Tuple[List[Tensor], List[Tensor]]:
-        feature_idx_to_tensor_idx = _get_feature_idx_to_tensor_idx(
-            formatted_feature_mask
+        feature_idx_to_tensor_idx = self._get_feature_idx_to_tensor_idx(
+            formatted_feature_mask, **kwargs
         )
         all_feature_idxs = list(feature_idx_to_tensor_idx.keys())
 
@@ -575,6 +574,7 @@ def _attribute_with_cross_tensor_feature_masks(
                     current_feature_idxs,
                     feature_idx_to_tensor_idx,
                     current_num_ablated_features,
+                    **kwargs,
                 )
             )
 
@@ -613,6 +613,21 @@ def _attribute_with_cross_tensor_feature_masks(
             )
         return total_attrib, weights
 
+    def _get_feature_idx_to_tensor_idx(
+        self, formatted_feature_mask: Tuple[Tensor, ...], **kwargs: Any
+    ) -> Dict[int, List[int]]:
+        """
+        For a given tuple of tensors, return dict of tensor values to list of tensor
+        indices they appear in.
+        """
+        feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
+        for i, mask in enumerate(formatted_feature_mask):
+            for feature_idx in torch.unique(mask):
+                if feature_idx.item() not in feature_idx_to_tensor_idx:
+                    feature_idx_to_tensor_idx[feature_idx.item()] = []
+                feature_idx_to_tensor_idx[feature_idx.item()].append(i)
+        return feature_idx_to_tensor_idx
+
     def _should_skip_inputs_and_warn(
         self,
         current_feature_idxs: List[int],
@@ -656,6 +671,7 @@ def _construct_ablated_input_across_tensors(
         feature_idxs: List[int],
         feature_idx_to_tensor_idx: Dict[int, List[int]],
         current_num_ablated_features: int,
+        **kwargs: Any,
     ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
         ablated_inputs = []
         current_masks: List[Optional[Tensor]] = []
@@ -946,8 +962,8 @@ def _attribute_with_cross_tensor_feature_masks_future(
         perturbations_per_eval: int,
         **kwargs: Any,
     ) -> Future[Union[Tensor, Tuple[Tensor, ...]]]:
-        feature_idx_to_tensor_idx = _get_feature_idx_to_tensor_idx(
-            formatted_feature_mask
+        feature_idx_to_tensor_idx = self._get_feature_idx_to_tensor_idx(
+            formatted_feature_mask, **kwargs
         )
         all_feature_idxs = list(feature_idx_to_tensor_idx.keys())
 
@@ -1016,6 +1032,7 @@ def _attribute_with_cross_tensor_feature_masks_future(
                     current_feature_idxs,
                     feature_idx_to_tensor_idx,
                     current_num_ablated_features,
+                    **kwargs,
                 )
            )
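
The `**kwargs` threading above lets subclass overrides of `_get_feature_idx_to_tensor_idx` and `_construct_ablated_input_across_tensors` consume perturbation-specific parameters such as Occlusion's `shift_counts`. Around these calls, `FeatureAblation` processes feature ids in groups of at most `perturbations_per_eval`; a minimal sketch of that chunking (the function name is hypothetical):

from typing import Iterator, List

def chunk_feature_idxs(
    all_feature_idxs: List[int], perturbations_per_eval: int
) -> Iterator[List[int]]:
    # Each chunk is ablated together in a single expanded forward pass.
    for start in range(0, len(all_feature_idxs), perturbations_per_eval):
        yield all_feature_idxs[start : start + perturbations_per_eval]

print(list(chunk_feature_idxs([0, 1, 2, 3, 4], 2)))  # [[0, 1], [2, 3], [4]]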

captum/attr/_core/feature_permutation.py

Lines changed: 1 addition & 0 deletions
@@ -377,6 +377,7 @@ def _construct_ablated_input_across_tensors(
         feature_idxs: List[int],
         feature_idx_to_tensor_idx: Dict[int, List[int]],
         current_num_ablated_features: int,
+        **kwargs: Any,
     ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
         current_masks: List[Optional[Tensor]] = []
         tensor_idxs = {
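
This one-line change keeps `FeaturePermutation`'s override call-compatible: the base class now forwards `**kwargs` into `_construct_ablated_input_across_tensors`, so an override without `**kwargs` would raise a `TypeError` at call time. A minimal sketch of the pattern (class and method names hypothetical):

class Base:
    def hook(self, x: int, **kwargs: object) -> int:
        return x

class Sub(Base):
    # Must keep **kwargs even if unused, since Base's caller passes extras.
    def hook(self, x: int, **kwargs: object) -> int:
        return x + 1

print(Sub().hook(1, shift_counts=((2,),)))  # 2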

captum/attr/_core/occlusion.py

Lines changed: 92 additions & 3 deletions
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 # pyre-strict
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -267,6 +267,7 @@ def attribute( # type: ignore
             shift_counts=tuple(shift_counts),
             strides=strides,
             show_progress=show_progress,
+            enable_cross_tensor_attribution=True,
         )
 
     def attribute_future(self) -> None:
@@ -310,6 +311,7 @@ def _construct_ablated_input(
                     kwargs["sliding_window_tensors"],
                     kwargs["strides"],
                     kwargs["shift_counts"],
+                    is_expanded_input=True,
                 )
                 for j in range(start_feature, end_feature)
             ],
@@ -327,11 +329,12 @@
 
     def _occlusion_mask(
         self,
-        expanded_input: Tensor,
+        input: Tensor,
         ablated_feature_num: int,
         sliding_window_tsr: Tensor,
         strides: Union[int, Tuple[int, ...]],
         shift_counts: Tuple[int, ...],
+        is_expanded_input: bool,
     ) -> Tensor:
         """
         This constructs the current occlusion mask, which is the appropriate
@@ -365,8 +368,9 @@
             current_index.append((remaining_total % shift_count) * stride)
             remaining_total = remaining_total // shift_count
 
+        dim = 2 if is_expanded_input else 1
         remaining_padding = np.subtract(
-            expanded_input.shape[2:], np.add(current_index, sliding_window_tsr.shape)
+            input.shape[dim:], np.add(current_index, sliding_window_tsr.shape)
         )
         pad_values = [
             val for pair in zip(remaining_padding, current_index) for val in pair
@@ -391,3 +395,88 @@ def _get_feature_counts(
     ) -> Tuple[int, ...]:
         """return the numbers of possible input features"""
         return tuple(np.prod(counts).astype(int) for counts in kwargs["shift_counts"])
+
+    def _get_feature_idx_to_tensor_idx(
+        self, formatted_feature_mask: Tuple[Tensor, ...], **kwargs: Any
+    ) -> Dict[int, List[int]]:
+        feature_idx_to_tensor_idx = {}
+        curr_feature_idx = 0
+        for i, shift_count in enumerate(kwargs["shift_counts"]):
+            num_features = int(np.prod(shift_count))
+            for _ in range(num_features):
+                feature_idx_to_tensor_idx[curr_feature_idx] = [i]
+                curr_feature_idx += 1
+        return feature_idx_to_tensor_idx
+
+    def _get_accumulated_shift_count_products(
+        self,
+        shift_counts: Tuple[int, ...],
+    ) -> List[int]:
+        shift_count_prod = [np.prod(counts).astype(int) for counts in shift_counts]
+        curr_prod = 1
+        acc_prod = [0]
+        for i in range(1, len(shift_count_prod)):
+            curr_prod *= shift_count_prod[i - 1]
+            acc_prod.append(curr_prod)
+        return acc_prod
+
+    def _construct_ablated_input_across_tensors(
+        self,
+        inputs: Tuple[Tensor, ...],
+        input_mask: Tuple[Tensor, ...],
+        baselines: BaselineType,
+        feature_idxs: List[int],
+        feature_idx_to_tensor_idx: Dict[int, List[int]],
+        current_num_ablated_features: int,
+        **kwargs: Any,
+    ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
+        ablated_inputs = []
+        current_masks: List[Optional[Tensor]] = []
+        tensor_idxs = {
+            tensor_idx
+            for sublist in (
+                feature_idx_to_tensor_idx[feature_idx] for feature_idx in feature_idxs
+            )
+            for tensor_idx in sublist
+        }
+
+        accumulated_shift_count_prods = self._get_accumulated_shift_count_products(
+            kwargs["shift_counts"]
+        )
+        for i, input_tensor in enumerate(inputs):
+            if i not in tensor_idxs:
+                ablated_inputs.append(input_tensor)
+                current_masks.append(None)
+                continue
+            tensor_mask: List[Tensor] = []
+            baseline = baselines[i] if isinstance(baselines, tuple) else baselines
+            for feature_idx in feature_idxs:
+
+                if feature_idx_to_tensor_idx[feature_idx][0] != i:
+                    tensor_mask.append(torch.zeros((1,) + input_tensor.shape[1:]))
+                    continue
+                ablated_feature_num = feature_idx - accumulated_shift_count_prods[i]
+                mask = self._occlusion_mask(
+                    input_tensor,
+                    ablated_feature_num,
+                    kwargs["sliding_window_tensors"][i],
+                    kwargs["strides"][i],
+                    kwargs["shift_counts"][i],
+                    is_expanded_input=False,
+                )
+                tensor_mask.append(mask)
+            assert baseline is not None, "baseline must be provided"
+            current_mask = torch.stack(tensor_mask, dim=0)
+            current_masks.append(current_mask)
+            ablated_input = input_tensor.clone().reshape(
+                (current_num_ablated_features, -1) + tuple(input_tensor.shape[1:])
+            )
+            ablated_input = (
+                ablated_input
+                * (
+                    torch.ones(1, dtype=torch.long, device=input_tensor.device)
+                    - current_mask
+                ).to(input_tensor.dtype)
+            ) + (baseline * current_mask.to(input_tensor.dtype))
+            ablated_inputs.append(ablated_input.reshape(input_tensor.shape))
+        return tuple(ablated_inputs), tuple(current_masks)
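
A worked illustration of the two new helpers above, using hypothetical shift counts: occlusion assigns globally unique feature ids across input tensors, and the accumulated products convert a global id back into a per-tensor window index.

import numpy as np

# Tensor 0 has 2 * 2 = 4 sliding-window positions, tensor 1 has 3.
shift_counts = ((2, 2), (3,))

# Mirrors Occlusion._get_feature_idx_to_tensor_idx: each global feature id
# belongs to exactly one tensor.
mapping, curr = {}, 0
for i, counts in enumerate(shift_counts):
    for _ in range(int(np.prod(counts))):
        mapping[curr] = [i]
        curr += 1
print(mapping)  # {0: [0], 1: [0], 2: [0], 3: [0], 4: [1], 5: [1], 6: [1]}

# Mirrors _get_accumulated_shift_count_products: offsets [0, 4], so global
# id 5 becomes per-tensor window index 5 - 4 = 1 on tensor 1.
offsets, prod = [0], 1
for counts in shift_counts[:-1]:
    prod *= int(np.prod(counts))
    offsets.append(prod)
print(offsets)  # [0, 4]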

tests/attr/test_occlusion.py

Lines changed: 37 additions & 0 deletions
@@ -144,6 +144,18 @@ def test_simple_multi_input(self) -> None:
             sliding_window_shapes=((1,), (1,)),
         )
 
+    def test_simple_multi_input_diff_shapes(self) -> None:
+        net = BasicModel3()
+        inp1 = torch.tensor([[-10.0, 3.0], [8.0, 4.0]])
+        inp2 = torch.tensor([[-5.0], [1.0]])
+        self._occlusion_test_assert(
+            net,
+            (inp1, inp2),
+            (torch.tensor([[0.0, 0.0], [6.0, 6.0]]), torch.tensor([[0.0], [-1.0]])),
+            sliding_window_shapes=((2,), (1,)),
+            perturbations_per_eval=(2,),
+        )
+
     def test_simple_multi_input_0d(self) -> None:
         net = BasicModel3()
         inp1 = torch.tensor([-10.0, 3.0])
@@ -282,6 +294,31 @@ def test_simple_multi_input_conv(self) -> None:
             strides=((1, 2, 1), (1, 1, 2)),
         )
 
+    def test_simple_multi_input_conv_diff_shapes(self) -> None:
+        net = BasicModel_ConvNet_One_Conv()
+        inp = torch.arange(16, dtype=torch.float).view(1, 1, 4, 4)
+        inp2 = torch.ones((1, 1, 4, 1))
+        self._occlusion_test_assert(
+            net,
+            (inp, inp2),
+            (
+                [
+                    [
+                        [
+                            [17.0, 17.0, 17.0, 10.0],
+                            [17.0, 17.0, 17.0, 10.0],
+                            [64.0, 64.0, 64.0, 13.0],
+                            [64.0, 64.0, 64.0, 13.0],
+                        ]
+                    ]
+                ],
+                [[[[13.0], [13.0], [13.0], [0.0]]]],
+            ),
+            perturbations_per_eval=(1, 2, 4, 8, 12, 16),
+            sliding_window_shapes=((1, 2, 3), (1, 3, 1)),
+            strides=((1, 2, 3), (1, 3, 1)),
+        )
+
     def test_futures_not_implemented(self) -> None:
         net = BasicModel_ConvNet_One_Conv()
         occ = Occlusion(net)
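
For readers outside the test suite, here is a self-contained sketch of what the first new test exercises: occlusion over two inputs with different shapes and different sliding windows. The toy `forward` stands in for `BasicModel3`, so the attribution values differ from the test's expectations.

import torch
from captum.attr import Occlusion

def forward(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    # Toy model: one scalar output per example.
    return (x1.sum(dim=1, keepdim=True) + x2).sum(dim=1)

occ = Occlusion(forward)
attr1, attr2 = occ.attribute(
    (torch.tensor([[-10.0, 3.0], [8.0, 4.0]]), torch.tensor([[-5.0], [1.0]])),
    sliding_window_shapes=((2,), (1,)),  # one window spans both columns of x1
    perturbations_per_eval=2,
)
print(attr1.shape, attr2.shape)  # torch.Size([2, 2]) torch.Size([2, 1])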
