Description
Motivation
The computation of Concept Activation Vectors (CAVs) is a fundamental component of concept-based explanation methods, most notably Testing with Concept Activation Vectors (TCAV), which is already implemented in Captum. The current standard for computing CAVs in Captum relies on training a Support Vector Machine (SVM) for each concept. This can present a significant challenge, especially for modern, high-dimensional models. To address this, we propose adding FastCAV as an efficient drop-in replacement. We provide benchmarking and theoretical justification to support its effectiveness.
Captum FastCAV API Design
Background
FastCAV is a novel approach introduced at ICML 2025, based on insights into superposition, that accelerates the computation of CAVs. FastCAV defines the CAV as the vector from the global mean of all activations to the mean of the positive-class activations. As such, it acts as a drop-in replacement for the prevalent SVM in the balanced-class setting. Additionally, we provide mathematical assumptions under which FastCAV is equivalent to an SVM.
Requires:
- Model: The trained neural network model $f$ to be interpreted.
- Layer: A specific layer $l$ within the model from which to extract activations.
- Concept Dataset $D_c$: A set of input examples that visually represent the concept to be analyzed.
- Control Dataset $D_r$: A set of random or control examples used as a baseline.
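For reference, a minimal formalization of this definition (the notation $a_l(x)$, $A_c$, $A_r$, $\mu$, and $v_c$ is ours, not taken from the text above): with $A_c = \{a_l(x) : x \in D_c\}$ and $A_r = \{a_l(x) : x \in D_r\}$ the layer-$l$ activations of the concept and control examples, FastCAV computes

$$\mu = \frac{1}{|A_c \cup A_r|} \sum_{a \in A_c \cup A_r} a, \qquad v_c = \frac{1}{|A_c|} \sum_{a \in A_c} a \;-\; \mu,$$

i.e., the vector from the global mean of all activations to the mean of the concept activations, matching the fit pseudocode below.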
Pseudocode:
```python
def fit(self, x: Tensor, y: Tensor) -> None:
    # Unique labels, sorted ascending; the last entry is the positive (concept) class.
    self.classes_ = torch.unique(y)
    self.mean = x.mean(dim=0)  # global mean of all activations
    # CAV: vector from the global mean to the positive-class mean; the bias places
    # the decision boundary at the global mean.
    self.coef_ = (x[y == self.classes_[-1]] - self.mean).mean(dim=0).unsqueeze(0)
    self.intercept_ = (-self.coef_ @ self.mean).unsqueeze(1)
```
Benchmarking
Performance benchmarking:
Given the example in the tutorial tutorials/TCAV_Image.ipynb, we achieve an average speed-up of ~8.26× over the DefaultClassifier (SVM: 1.09 …). For bigger models this improves even more drastically, as shown in Table 1 of the paper.
Visual Benchmarking
Given the example in the tutorial tutorials/TCAV_Image.ipynb, we get the following results for experimental_set_rand and experimental_set_zig_dot with FastCAV and SVM (result plots attached).
Mathematical Assumptions:
Under the following assumptions, FastCAV is equivalent to an SVM:
- Gaussian Distribution: The activation vectors for both the random samples and the concept samples are assumed to follow independent multivariate Gaussian distributions.
- Equal Mixture: The set of concept examples and the set of random examples are of equal size ($|D_c| = |D_r|$), resulting in a uniform mixture of the two Gaussian distributions.
- Isotropic Covariance: The within-class covariance matrices are assumed to be isotropic, meaning they are proportional to the identity matrix. This is a critical assumption that makes the FastCAV solution equivalent to the solution of a Fisher discriminant analysis.
- High-Dimensionality: The method is analyzed in the context of high-dimensional activation spaces, where the number of dimensions $d$ is significantly larger than the number of samples $n$ ($d \gg n$). In such spaces, the set of support vectors used by an SVM is likely to contain most of the training samples, making the SVM solution approximate the Fisher discriminant solution and, by extension, the FastCAV solution.
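As an informal, illustrative check of this argument (not from the paper; all names and parameter choices below are ours), one can sample two isotropic Gaussian clusters in a $d \gg n$ regime and compare the FastCAV direction with a linear SVM direction via cosine similarity:

```python
import numpy as np
from sklearn.svm import LinearSVC

rng = np.random.default_rng(0)
n, d = 50, 2048                                # d >> n: high-dimensional activation space
mu_c = rng.normal(size=d) * 0.5                # concept-class mean (illustrative)
x_concept = rng.normal(size=(n, d)) + mu_c     # isotropic Gaussian around mu_c
x_random = rng.normal(size=(n, d))             # isotropic Gaussian control samples

x = np.concatenate([x_random, x_concept])
y = np.concatenate([np.zeros(n), np.ones(n)])

# FastCAV direction: positive-class mean minus global mean.
cav = x_concept.mean(axis=0) - x.mean(axis=0)

# Linear SVM direction on the same data (large C approximates a hard margin).
svm = LinearSVC(C=1e3, max_iter=10000).fit(x, y)
w = svm.coef_.ravel()

cos = np.dot(cav, w) / (np.linalg.norm(cav) * np.linalg.norm(w))
print(f"cosine similarity between FastCAV and SVM directions: {cos:.3f}")
```

Under these conditions the two directions should be nearly parallel; relaxing the isotropy or equal-size assumptions is expected to reduce the alignment.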
Proposed Captum API Design:
The integration of FastCAV into Captum is designed as a high-performance drop-in replacement for the default DefaultClassifier, which utilizes an SVM. It is exposed to the user through the FastCAVClassifier class, which leverages a FastCAVLinearModel internally.
E.g. for the tutorial tutorials/TCAV_Image.ipynb:
```python
from captum.attr import LayerIntegratedGradients
from captum.concept import TCAV
from captum.concept._utils import classifier  # module that would expose FastCAVClassifier

fast_clf = classifier.FastCAVClassifier()
mytcav = TCAV(
    model=model,
    layers=layers,
    classifier=fast_clf,
    layer_attr_method=LayerIntegratedGradients(model, None, multiply_by_inputs=False),
)
```
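The downstream call from the tutorial is unchanged; only the classifier differs. As a sketch (variable names such as zebra_tensors, experimental_set_rand, and zebra_ind are assumed from the tutorial and shown purely for illustration):

```python
tcav_scores = mytcav.interpret(
    inputs=zebra_tensors,
    experimental_sets=experimental_set_rand,
    target=zebra_ind,
    n_steps=5,
)
```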
FastCAV
This is a low-level utility class, similar in interface to scikit-learn classifiers. It contains the core logic for computing the CAV. It is not a torch.nn.Module and is not intended for direct use within most Captum workflows, but it provides the fundamental algorithm.
Constructor:
FastCAV(**kwargs)
Argument Descriptions:
- kwargs: The constructor currently accepts but ignores any keyword arguments, to maintain a consistent interface with other classifiers.
Methods:
- fit(x: Tensor, y: Tensor): Takes a tensor x of activations and a tensor y of labels and computes the coef_ and intercept_.
  - x: A 2D tensor of shape (n_samples, n_features) containing the input data (e.g., model activations).
  - y: A 1D tensor of shape (n_samples,) containing binary labels (0 or 1).
- predict(x: Tensor) -> Tensor: Predicts class labels for new data points based on the fitted hyperplane.
- classes() -> Tensor: Returns the unique class labels the model was fitted on.
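To make the interface above concrete, here is a minimal sketch of what predict and classes could look like on top of the fit logic shown earlier; this is an illustration consistent with the pseudocode, not the final implementation.

```python
import torch
from torch import Tensor


class FastCAV:
    """Sketch of the low-level FastCAV utility (scikit-learn-like interface)."""

    def __init__(self, **kwargs) -> None:
        # kwargs are accepted but ignored to mirror other classifier constructors.
        pass

    def fit(self, x: Tensor, y: Tensor) -> None:
        self.classes_ = torch.unique(y)  # sorted labels; the last one is the concept class
        self.mean = x.mean(dim=0)        # global mean of all activations
        self.coef_ = (x[y == self.classes_[-1]] - self.mean).mean(dim=0).unsqueeze(0)
        self.intercept_ = (-self.coef_ @ self.mean).unsqueeze(1)

    def predict(self, x: Tensor) -> Tensor:
        # Signed distance to the hyperplane; the positive side maps to the concept class.
        scores = x @ self.coef_.T + self.intercept_.T
        return self.classes_[(scores.squeeze(1) > 0).long()]

    def classes(self) -> Tensor:
        return self.classes_
```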
fastcav_train_linear_model
This function serves as the bridge between the FastCAV logic and Captum's LinearModel interface. It is designed to be used as the train_fn for a LinearModel. It orchestrates the process of fitting a FastCAV instance and then using its learned parameters to configure a LinearModel object; a sketch of this wiring follows the argument descriptions below.
Signature:
fastcav_train_linear_model(model: LinearModel, dataloader: DataLoader, construct_kwargs: Dict[str, Any], norm_input: bool = False, **fit_kwargs: Any) -> Dict[str, float]
Argument Descriptions:
- model: The LinearModel instance to be configured.
- dataloader: A DataLoader providing the training data (activations and labels). Right now this follows sklearn_train_linear_model and iterates through the entire dataloader to collect all data into memory.
- construct_kwargs: Keyword arguments passed to the FastCAV constructor. As FastCAV ignores its kwargs, this parameter exists only to maintain a consistent interface.
- norm_input: A boolean indicating whether to normalize the input data before fitting.
- fit_kwargs: Additional keyword arguments for the fit method (currently unused by FastCAV).
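For illustration, a rough sketch of how this training function could be wired, modeled on Captum's existing sklearn_train_linear_model; the use and keyword names of model._construct_model_params are assumptions based on that function and may need to be adapted to the actual LinearModel API.

```python
from typing import Any, Dict

import torch
from torch.utils.data import DataLoader

from captum._utils.models.linear_model import LinearModel


def fastcav_train_linear_model(
    model: LinearModel,
    dataloader: DataLoader,
    construct_kwargs: Dict[str, Any],
    norm_input: bool = False,
    **fit_kwargs: Any,
) -> Dict[str, float]:
    # Collect the entire dataloader into memory, mirroring sklearn_train_linear_model.
    xs, ys = [], []
    for data in dataloader:
        x, y = data[0], data[1]  # (activations, labels); any extra fields are ignored
        xs.append(x)
        ys.append(y)
    x = torch.cat(xs)
    y = torch.cat(ys)

    if norm_input:
        # Simple standardization as a placeholder for the actual normalization.
        x = (x - x.mean(dim=0)) / (x.std(dim=0) + 1e-8)

    clf = FastCAV(**construct_kwargs)  # FastCAV sketch from above; fit_kwargs currently unused
    clf.fit(x, y)

    # Transfer the learned hyperplane into the LinearModel; the parameter names here
    # follow sklearn_train_linear_model and are an assumption.
    model._construct_model_params(
        norm_type=None,
        weight_values=clf.coef_,
        bias_value=clf.intercept_.squeeze().unsqueeze(0),
        classes=clf.classes(),
    )
    return {}
```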
FastCAVLinearModel
This class acts as a bridge, wrapping the FastCAV logic into the LinearModel interface that is standard within Captum's concept utilities. This allows it to be used by higher-level abstractions.
Constructor:
FastCAVLinearModel(**kwargs)
- The constructor simply calls the parent LinearModel constructor, passing fastcav_train_linear_model as the train_fn. Any kwargs are stored and passed to the training function during the fit call.
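A sketch of how thin this wrapper could be (assuming LinearModel accepts a train_fn keyword, as Captum's other linear-model wrappers do):

```python
from captum._utils.models.linear_model import LinearModel


class FastCAVLinearModel(LinearModel):
    """Sketch: LinearModel wrapper that trains via fastcav_train_linear_model."""

    def __init__(self, **kwargs) -> None:
        # kwargs are forwarded and later passed to the training function by fit().
        super().__init__(train_fn=fastcav_train_linear_model, **kwargs)
```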
FastCAVClassifier
This is the main user-facing class. It is a drop-in replacement for DefaultClassifier for users of high-level APIs like TCAV, or of future concept-based explanation methods such as ACE. By simply switching the classifier, users can leverage the performance benefits of FastCAV without changing the rest of their workflow. It inherits from DefaultClassifier, using a FastCAVLinearModel as its internal engine.
Constructor:
FastCAVClassifier()
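Finally, a sketch of the user-facing class; the assumption here is that DefaultClassifier keeps its linear model in an attribute (lm in the current Captum source), which the subclass simply swaps out:

```python
from captum.concept._utils.classifier import DefaultClassifier


class FastCAVClassifier(DefaultClassifier):
    """Sketch: DefaultClassifier that computes CAVs with FastCAV instead of an SVM."""

    def __init__(self) -> None:
        super().__init__()
        # Replace the default SVM-based linear model with the FastCAV-backed one.
        self.lm = FastCAVLinearModel()
```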