
Commit 761a219

aobo-y authored and facebook-github-bot committed
validate forward_func output shape in FeatureAblation (#1091)
Summary:
Correctly validate the `forward_func`'s output shape in `FeatureAblation` (and `FeaturePermutation`).

Abandoned the previous flawed assumption of "aggregation mode", which forbade support for multi-output models (ref #1047).

The new logic does not check the output shape when `perturbations_per_eval == 1`. Only when `perturbations_per_eval > 1` does it require "non-aggregation mode", defined as: the 1st dim of the model's output must grow with the input's batch size in the same ratio. This is verified by actually comparing the output shapes of 2 different inputs instead of making any assumption based on other user config:
- The baseline output is from the initial eval with the original inputs, which we have to run anyway.
- The expanded output is from the 1st ablated eval, whose input batch size has been expanded for more feature perturbations.

This approach does not introduce any extra forward calls.

Pull Request resolved: #1091

Reviewed By: vivekmig

Differential Revision: D42027843

Pulled By: aobo-y

fbshipit-source-id: cbcbad64bb1695e7be9c9447ebde6ec3f0cb8a90
1 parent ecc81e6 commit 761a219
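The shape validation the summary describes can be pictured in isolation. The helper below is illustrative only (its name and signature are not part of Captum); it mirrors the comparison of the baseline output against the expanded-batch output:

import torch

def validate_output_grows(forward_func, inputs, expanded_inputs):
    # baseline output: the initial eval on the original inputs (run anyway)
    baseline_out = forward_func(inputs)
    # expanded output: the 1st ablated eval on the perturbation-expanded batch
    expanded_out = forward_func(expanded_inputs)

    n_perturb = expanded_inputs.shape[0] / inputs.shape[0]
    # non-aggregation mode: outputs are non-scalar and the 1st dim
    # grows in the same ratio as the input batch size
    assert baseline_out.shape and expanded_out.shape and (
        expanded_out.shape[0] == n_perturb * baseline_out.shape[0]
    ), "forward_func's output should grow with the input batch size"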

2 files changed: 42 additions, 54 deletions

captum/attr/_core/feature_ablation.py

Lines changed: 42 additions & 43 deletions
@@ -53,6 +53,16 @@ def __init__(self, forward_func: Callable) -> None:
         PerturbationAttribution.__init__(self, forward_func)
         self.use_weights = False
 
+        # only used when perturbations_per_eval > 1, where the 1st dim of
+        # forward_func's output must grow with the input batch size. If forward's
+        # output is aggregated, we cannot expand the input to include more
+        # perturbations in one call. While this flag is False, we force the
+        # validation by comparing the outputs of the original input and of the
+        # modified input whose batch size is expanded by perturbations_per_eval.
+        # The flag turns True once the modified input's output grows as expected;
+        # after that we assume the model's behavior stays consistent and skip the check.
+        self._is_output_shape_valid = False
+
     @log_usage()
     def attribute(
         self,
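To see what this flag guards against, consider a forward that aggregates over the batch. A minimal sketch (the model below is made up for illustration):

import torch
from captum.attr import FeatureAblation

def agg_forward(inp):
    # aggregated output: a single scalar regardless of batch size
    return inp.sum()

inp = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

# works: with perturbations_per_eval=1 the output shape is never validated
attr = FeatureAblation(agg_forward).attribute(inp, perturbations_per_eval=1)

# fails: the scalar output cannot grow with the expanded batch, so the
# assertion introduced in this commit is expected to raise
# FeatureAblation(agg_forward).attribute(inp, perturbations_per_eval=2)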
@@ -291,21 +301,10 @@ def attribute(
 
         # flatten eval outputs into 1D (n_outputs)
         # add the leading dim for n_feature_perturbed
-        initial_eval = initial_eval.reshape(1, -1)
-
-        agg_output_mode = FeatureAblation._find_output_mode(
-            perturbations_per_eval, feature_mask
-        )
-
-        if not agg_output_mode:
-            assert n_outputs == num_examples, (
-                "expected output of `forward_func` to have "
-                + "`batch_size` elements for perturbations_per_eval > 1 "
-                + "and all feature_mask.shape[0] > 1"
-            )
+        flattened_initial_eval = initial_eval.reshape(1, -1)
 
         # Initialize attribution totals and counts
-        attrib_type = cast(dtype, initial_eval.dtype)
+        attrib_type = cast(dtype, flattened_initial_eval.dtype)
 
         total_attrib = [
             # attribute w.r.t each output element
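The rename is deliberate: `initial_eval` keeps its original shape so the later validation can compare it against the expanded output, while `flattened_initial_eval` is the (1, n_outputs) view used for the attribution arithmetic. A toy illustration of the two shapes (numbers are arbitrary):

import torch

initial_eval = torch.randn(4, 2)        # e.g., 4 examples, 2 outputs each
flattened_initial_eval = initial_eval.reshape(1, -1)

print(initial_eval.shape)               # torch.Size([4, 2]), kept for the shape check
print(flattened_initial_eval.shape)     # torch.Size([1, 8]), i.e., (1, n_outputs)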
@@ -362,21 +361,43 @@ def attribute(
                     if show_progress:
                         attr_progress.update()
 
-                    if not agg_output_mode:
-                        # current_batch_size is not n_examples
-                        # it may get expanded by n_feature_perturbed
+                    # if perturbations_per_eval > 1, the output shape must grow with
+                    # the input and must not be aggregated
+                    if perturbations_per_eval > 1 and not self._is_output_shape_valid:
                         current_batch_size = current_inputs[0].shape[0]
+
+                        # number of perturbations, which is not the same as
+                        # perturbations_per_eval when not enough features to perturb
+                        n_perturb = current_batch_size / num_examples
+
+                        current_output_shape = modified_eval.shape
+
+                        # use initial_eval as the forward of perturbations_per_eval = 1
+                        initial_output_shape = initial_eval.shape
+
                         assert (
-                            modified_eval.numel() == current_batch_size
-                        ), """expected output of forward_func to grow with
-                        batch_size. If this is not the case for your model
-                        please set perturbations_per_eval = 1"""
+                            # check that the output is not a scalar
+                            current_output_shape
+                            and initial_output_shape
+                            # check that the output grows in the same ratio, i.e., not agg
+                            and current_output_shape[0]
+                            == n_perturb * initial_output_shape[0]
+                        ), (
+                            "When perturbations_per_eval > 1, forward_func's output "
+                            "should be a tensor whose 1st dim grows with the input "
+                            f"batch size: when input batch size is {num_examples}, "
+                            f"the output shape is {initial_output_shape}; "
+                            f"when input batch size is {current_batch_size}, "
+                            f"the output shape is {current_output_shape}"
+                        )
+
+                        self._is_output_shape_valid = True
 
                     # reshape the leading dim for n_feature_perturbed
                     # flatten each feature's eval outputs into 1D of (n_outputs)
                     modified_eval = modified_eval.reshape(-1, n_outputs)
                     # eval_diff in shape (n_feature_perturbed, n_outputs)
-                    eval_diff = initial_eval - modified_eval
+                    eval_diff = flattened_initial_eval - modified_eval
 
                     # append the shape of one input example
                     # to make it broadcastable to mask
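Plugging in concrete numbers (illustrative only): with num_examples = 2 and perturbations_per_eval = 3, a full expanded batch has 6 rows, so n_perturb = 6 / 2 = 3 and the assertion expects the output's 1st dim to be 3x the initial one; if only 2 features remain to perturb, the batch shrinks to 4 rows and n_perturb = 2, which the ratio check handles without special-casing:

num_examples = 2                    # original input batch size
perturbations_per_eval = 3
initial_first_dim = 2               # 1st dim of the initial eval's output

current_batch_size = 6              # full expansion: 2 examples * 3 perturbations
n_perturb = current_batch_size / num_examples        # 3.0
expected_first_dim = n_perturb * initial_first_dim   # 6.0

current_batch_size = 4              # last round with only 2 features left
n_perturb = current_batch_size / num_examples        # 2.0
expected_first_dim = n_perturb * initial_first_dim   # 4.0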
@@ -572,28 +593,6 @@ def _get_feature_counts(self, inputs, feature_mask, **kwargs):
             for inp, mask in zip(inputs, feature_mask)
         )
 
-    @staticmethod
-    def _find_output_mode(
-        perturbations_per_eval: int,
-        feature_mask: Union[None, TensorOrTupleOfTensorsGeneric],
-    ) -> bool:
-        """
-        Returns True if the output mode is "aggregation output mode"
-
-        Aggregation output mode is defined as: when there is no 1:1 correspondence
-        with the `num_examples` (`batch_size`) and the amount of outputs your model
-        produces, i.e. the model output does not grow in size as the input becomes
-        larger.
-
-        We assume this is the case if `perturbations_per_eval == 1`
-        and your feature mask is None or is associated to all
-        examples in a batch (fm.shape[0] == 1 for all fm in feature_mask).
-        """
-        return perturbations_per_eval == 1 and (
-            feature_mask is None
-            or all(len(sm.shape) == 0 or sm.shape[0] == 1 for sm in feature_mask)
-        )
-
     def _strict_run_forward(self, *args, **kwargs) -> Tensor:
         """
         A temp wrapper for global _run_forward util to force forward output

tests/attr/test_feature_ablation.py

Lines changed: 0 additions & 11 deletions
@@ -345,17 +345,6 @@ def forward_func(inp):
         with self.assertRaises(AssertionError):
             _ = ablation.attribute(inp, perturbations_per_eval=2)
 
-    def test_error_agg_mode_incorrect_fm(self) -> None:
-        def forward_func(inp):
-            return inp[0].unsqueeze(0)
-
-        inp = torch.tensor([[1, 2, 3], [4, 5, 6]])
-        mask = torch.tensor([[0, 1, 2], [0, 0, 1]])
-
-        ablation = FeatureAblation(forward_func)
-        with self.assertRaises(AssertionError):
-            _ = ablation.attribute(inp, perturbations_per_eval=1, feature_mask=mask)
-
     def test_empty_sparse_features(self) -> None:
         ablation_algo = FeatureAblation(BasicModelWithSparseInputs())
         inp1 = torch.tensor([[1.0, -2.0, 3.0], [2.0, -1.0, 3.0]])
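The deleted test encoded the old "aggregation mode" assumption: an output that does not grow with the batch, combined with a per-example feature mask, used to raise even at perturbations_per_eval=1. Under the new logic no shape check runs in that case, so the same call is expected to go through; a sketch adapted from the removed test:

import torch
from captum.attr import FeatureAblation

def forward_func(inp):
    return inp[0].unsqueeze(0)   # output does not grow with batch size

inp = torch.tensor([[1, 2, 3], [4, 5, 6]])
mask = torch.tensor([[0, 1, 2], [0, 0, 1]])

# validation is skipped when perturbations_per_eval == 1, so this
# no longer raises the AssertionError the removed test expected
attr = FeatureAblation(forward_func).attribute(
    inp, perturbations_per_eval=1, feature_mask=mask
)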
