diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py
index 6d088ef9b..30af2af52 100644
--- a/captum/attr/_core/feature_ablation.py
+++ b/captum/attr/_core/feature_ablation.py
@@ -729,6 +729,7 @@ def attribute_future(
         feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
         perturbations_per_eval: int = 1,
         show_progress: bool = False,
+        enable_cross_tensor_attribution: bool = False,
         **kwargs: Any,
     ) -> Future[TensorOrTupleOfTensorsGeneric]:
         r"""
@@ -743,17 +744,18 @@
         formatted_additional_forward_args = _format_additional_forward_args(
             additional_forward_args
         )
-        num_examples = formatted_inputs[0].shape[0]
         formatted_feature_mask = _format_feature_mask(feature_mask, formatted_inputs)
         assert (
             isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1
         ), "Perturbations per evaluation must be an integer and at least 1."
 
         with torch.no_grad():
+            attr_progress = None
             if show_progress:
                 attr_progress = self._attribute_progress_setup(
                     formatted_inputs,
                     formatted_feature_mask,
+                    enable_cross_tensor_attribution,
                     **kwargs,
                     perturbations_per_eval=perturbations_per_eval,
                 )
@@ -768,7 +770,7 @@
                 formatted_additional_forward_args,
             )
 
-            if show_progress:
+            if attr_progress is not None:
                 attr_progress.update()
 
             processed_initial_eval_fut: Optional[
@@ -788,101 +790,136 @@
                 )
             )
 
-            # The will be the same amount futures as modified_eval down there,
-            # since we cannot add up the evaluation result adhoc under async mode.
-            all_modified_eval_futures: List[
-                List[Future[Tuple[List[Tensor], List[Tensor]]]]
-            ] = [[] for _ in range(len(inputs))]
-            # Iterate through each feature tensor for ablation
-            for i in range(len(formatted_inputs)):
-                # Skip any empty input tensors
-                if torch.numel(formatted_inputs[i]) == 0:
-                    continue
-
-                for (
-                    current_inputs,
-                    current_add_args,
-                    current_target,
-                    current_mask,
-                ) in self._ith_input_ablation_generator(
-                    i,
+            if enable_cross_tensor_attribution:
+                raise NotImplementedError("Not supported yet")
+            else:
+                # pyre-fixme[7]: Expected `Future[Variable[TensorOrTupleOfTensorsGeneric
+                # <:[Tensor, typing.Tuple[Tensor, ...]]]]` but got
+                # `Future[Union[Tensor, typing.Tuple[Tensor, ...]]]`
+                return self._attribute_with_independent_feature_masks_future(  # type: ignore # noqa: E501 line too long
                     formatted_inputs,
                     formatted_additional_forward_args,
                     target,
                     baselines,
                     formatted_feature_mask,
                     perturbations_per_eval,
+                    attr_progress,
+                    processed_initial_eval_fut,
+                    is_inputs_tuple,
                     **kwargs,
-                ):
-                    # modified_eval has (n_feature_perturbed * n_outputs) elements
-                    # shape:
-                    #     agg mode: (*initial_eval.shape)
-                    #     non-agg mode:
-                    #         (feature_perturbed * batch_size, *initial_eval.shape[1:])
-                    modified_eval: Union[Tensor, Future[Tensor]] = _run_forward(
-                        self.forward_func,
-                        current_inputs,
-                        current_target,
-                        current_add_args,
-                    )
+                )
 
-                    if show_progress:
-                        attr_progress.update()
+    def _attribute_with_independent_feature_masks_future(
+        self,
+        formatted_inputs: Tuple[Tensor, ...],
+        formatted_additional_forward_args: Optional[Tuple[object, ...]],
+        target: TargetType,
+        baselines: BaselineType,
+        formatted_feature_mask: Tuple[Tensor, ...],
+        perturbations_per_eval: int,
+        attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]],
+        processed_initial_eval_fut: Future[
+            Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]
+        ],
+        is_inputs_tuple: bool,
+        **kwargs: Any,
+    ) -> Future[Union[Tensor, Tuple[Tensor, ...]]]:
+        num_examples = formatted_inputs[0].shape[0]
+
+        # There will be the same number of futures as modified_eval below,
+        # since we cannot add up the evaluation results ad hoc in async mode.
+        all_modified_eval_futures: List[
+            List[Future[Tuple[List[Tensor], List[Tensor]]]]
+        ] = [[] for _ in range(len(formatted_inputs))]
+        # Iterate through each feature tensor for ablation
+        for i in range(len(formatted_inputs)):
+            # Skip any empty input tensors
+            if torch.numel(formatted_inputs[i]) == 0:
+                continue
-
-                    if not isinstance(modified_eval, torch.Future):
-                        raise AssertionError(
-                            "when using attribute_future, modified_eval should have "
-                            f"Future type rather than {type(modified_eval)}"
-                        )
-                    if processed_initial_eval_fut is None:
-                        raise AssertionError(
-                            "processed_initial_eval_fut should not be None"
-                        )
+            for (
+                current_inputs,
+                current_add_args,
+                current_target,
+                current_mask,
+            ) in self._ith_input_ablation_generator(
+                i,
+                formatted_inputs,
+                formatted_additional_forward_args,
+                target,
+                baselines,
+                formatted_feature_mask,
+                perturbations_per_eval,
+                **kwargs,
+            ):
+                # modified_eval has (n_feature_perturbed * n_outputs) elements
+                # shape:
+                #     agg mode: (*initial_eval.shape)
+                #     non-agg mode:
+                #         (feature_perturbed * batch_size, *initial_eval.shape[1:])
+                modified_eval: Union[Tensor, Future[Tensor]] = _run_forward(
+                    self.forward_func,
+                    current_inputs,
+                    current_target,
+                    current_add_args,
+                )
 
-                    # Need to collect both initial eval and modified_eval
-                    eval_futs: Future[
-                        List[
-                            Future[
-                                Union[
-                                    Tuple[
-                                        List[Tensor],
-                                        List[Tensor],
-                                        Tensor,
-                                        Tensor,
-                                        int,
-                                        dtype,
-                                    ],
+                if attr_progress is not None:
+                    attr_progress.update()
+
+                if not isinstance(modified_eval, torch.Future):
+                    raise AssertionError(
+                        "when using attribute_future, modified_eval should have "
+                        f"Future type rather than {type(modified_eval)}"
+                    )
+                if processed_initial_eval_fut is None:
+                    raise AssertionError(
+                        "processed_initial_eval_fut should not be None"
+                    )
+
+                # Need to collect both initial eval and modified_eval
+                eval_futs: Future[
+                    List[
+                        Future[
+                            Union[
+                                Tuple[
+                                    List[Tensor],
+                                    List[Tensor],
+                                    Tensor,
                                     Tensor,
-                                ]
+                                    int,
+                                    dtype,
+                                ],
+                                Tensor,
                             ]
                         ]
-                    ] = collect_all(
-                        [
-                            processed_initial_eval_fut,
-                            modified_eval,
-                        ]
-                    )
+                    ]
+                ] = collect_all(
+                    [
+                        processed_initial_eval_fut,
+                        modified_eval,
+                    ]
+                )
 
-                    ablated_out_fut: Future[Tuple[List[Tensor], List[Tensor]]] = (
-                        eval_futs.then(
-                            lambda eval_futs, current_inputs=current_inputs, current_mask=current_mask, i=i: self._eval_fut_to_ablated_out_fut(  # type: ignore # noqa: E501 line too long
-                                eval_futs=eval_futs,
-                                current_inputs=current_inputs,
-                                current_mask=current_mask,
-                                i=i,
-                                perturbations_per_eval=perturbations_per_eval,
-                                num_examples=num_examples,
-                                formatted_inputs=formatted_inputs,
-                            )
+                ablated_out_fut: Future[Tuple[List[Tensor], List[Tensor]]] = (
+                    eval_futs.then(
+                        lambda eval_futs, current_inputs=current_inputs, current_mask=current_mask, i=i: self._eval_fut_to_ablated_out_fut(  # type: ignore # noqa: E501 line too long
+                            eval_futs=eval_futs,
+                            current_inputs=current_inputs,
+                            current_mask=current_mask,
+                            i=i,
+                            perturbations_per_eval=perturbations_per_eval,
+                            num_examples=num_examples,
+                            formatted_inputs=formatted_inputs,
+                        )
                     )
+                )
 
-                    all_modified_eval_futures[i].append(ablated_out_fut)
+                all_modified_eval_futures[i].append(ablated_out_fut)
 
-            if show_progress:
-                attr_progress.close()
+        if attr_progress is not None:
+            attr_progress.close()
 
-        return self._generate_async_result(all_modified_eval_futures, is_inputs_tuple)  # type: ignore # noqa: E501 line too long
+        return self._generate_async_result(all_modified_eval_futures, is_inputs_tuple)  # type: ignore # noqa: E501 line too long
 
     # pyre-fixme[3] return type must be annotated
     def _attribute_progress_setup(
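
Usage note: a minimal sketch (not part of this diff) of how the changed entry point behaves. It assumes only what the assertions above require, namely a forward function that returns a torch.futures.Future; the toy forward_fn, input shape, and expected output below are illustrative assumptions, not captum-documented guarantees.

import torch
from torch.futures import Future
from captum.attr import FeatureAblation

# Toy async forward: wraps a synchronous result in a Future, standing in for
# a model that evaluates asynchronously. attribute_future asserts that the
# forward returns a torch.Future, so a plain-tensor forward would fail here.
def forward_fn(x: torch.Tensor) -> Future:
    fut: Future = Future()
    fut.set_result(x.sum(dim=1))  # one scalar output per example
    return fut

fa = FeatureAblation(forward_fn)
inputs = torch.rand(2, 4)

# Default path: enable_cross_tensor_attribution=False routes through the new
# _attribute_with_independent_feature_masks_future helper.
attr_fut = fa.attribute_future(inputs, perturbations_per_eval=2)
attrs = attr_fut.wait()  # attribution tensor, same shape as inputs: (2, 4)

# The new flag is only plumbed through for now; passing
# enable_cross_tensor_attribution=True raises NotImplementedError("Not supported yet").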
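
The refactor preserves the existing future-chaining structure: torch.futures.collect_all joins processed_initial_eval_fut with each modified_eval, and .then() defers _eval_fut_to_ablated_out_fut until both complete. A self-contained sketch of that torch.futures pattern with toy values (not captum code):

import torch
from torch.futures import Future, collect_all

# Stand-in for an evaluation that is already asynchronous.
def async_eval(x: torch.Tensor) -> Future:
    fut: Future = Future()
    fut.set_result(x.sum())
    return fut

initial_fut = async_eval(torch.ones(4))    # plays the role of processed_initial_eval_fut
modified_fut = async_eval(torch.zeros(4))  # plays the role of modified_eval

# collect_all returns Future[List[Future]]; the .then() callback receives the
# completed joined future and runs only once both inputs are done, mirroring
# eval_futs.then(... _eval_fut_to_ablated_out_fut ...) in the diff.
joined = collect_all([initial_fut, modified_fut])
delta_fut = joined.then(
    lambda futs: futs.value()[0].value() - futs.value()[1].value()
)
print(delta_fut.wait())  # tensor(4.)

This is also why the diff keeps one future per perturbation in all_modified_eval_futures instead of accumulating a running total: under async mode the results are not available at loop time, so aggregation has to happen inside the chained callbacks.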