Adding a Fast Implementation to Calculate Concept Activation Vectors for High-Dimensional Models to Captum #1619

Description

Motivation

The computation of Concept Activation Vectors (CAVs) is a fundamental component of concept-based explanation methods, most notably Testing with Concept Activation Vectors (TCAV), which is already implemented in Captum. The current standard for computing CAVs in Captum relies on training a Support Vector Machine (SVM) for each concept. This can present a significant challenge, especially for modern, high-dimensional models. To address this, we propose adding FastCAV as an efficient drop-in replacement. We provide benchmarking results and a theoretical justification to support its effectiveness.

Captum FastCAV API Design

Background

FastCAV is a novel approach introduced at ICML 2025, based on insights into superposition, that accelerates the computation of CAVs. FastCAV defines the CAV as the vector from the global mean of all activations to the mean of the positive-class (concept) activations. As such, it acts as a drop-in replacement for the prevalent SVM in the case of balanced classes. Additionally, we provide mathematical assumptions under which FastCAV is equivalent to an SVM.

Requires:

  • Model: The trained neural network model $f$ to be interpreted.
  • Layer: A specific layer $l$ within the model from which to extract activations.
  • Concept Dataset $D_c$: A set of input examples that visually represent the concept to be analyzed.
  • Control Dataset $D_r$: A set of random or control examples used as a baseline.
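
Putting these together (writing $f_l(x)$ for the activation of layer $l$ of model $f$ on input $x$; this notation is ours, not from the paper), the FastCAV vector for a concept is the difference between the mean concept activation and the global mean over all activations:

$v_c = \frac{1}{|D_c|}\sum_{x \in D_c} f_l(x) \; - \; \frac{1}{|D_c \cup D_r|}\sum_{x \in D_c \cup D_r} f_l(x)$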

Pseudocode:

def fit(self, x: Tensor, y: Tensor) -> None:
    self.classes_ = torch.unique(y)  # last entry is treated as the positive (concept) class
    self.mean = x.mean(dim=0)        # global mean of all activations
    # CAV: vector from the global mean to the mean of the positive-class activations
    self.coef_ = (x[y == self.classes_[-1]] - self.mean).mean(dim=0).unsqueeze(0)
    self.intercept_ = (-self.coef_ @ self.mean).unsqueeze(1)

Benchmarking

Performance benchmarking:

Given the example in the tutorial tutorials/TCAV_Image.ipynb, we achieve an average speed-up over the DefaultClassifier of ~8.26x (SVM: 1.09 $\pm$ 0.39, FastCAV: 0.132 $\pm$ 0.025).
For larger models the improvement is even more pronounced, as shown in Table 1 of the paper.
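
A minimal sketch of how such a timing comparison could be reproduced, using synthetic activations, scikit-learn's LinearSVC as a stand-in for the SVM trained by the DefaultClassifier, and the FastCAV class from the pseudocode above; the shapes and sample counts are illustrative assumptions, not the benchmark setup from the paper:

import time

import torch
from sklearn.svm import LinearSVC

# Synthetic stand-in for layer activations: 50 concept and 50 random samples.
x = torch.randn(100, 4096)
y = torch.cat([torch.zeros(50), torch.ones(50)]).long()

def average_fit_time(fit_fn, repeats=10):
    # Average wall-clock time of one fit call over several repeats.
    start = time.perf_counter()
    for _ in range(repeats):
        fit_fn()
    return (time.perf_counter() - start) / repeats

svm_time = average_fit_time(lambda: LinearSVC().fit(x.numpy(), y.numpy()))
fastcav_time = average_fit_time(lambda: FastCAV().fit(x, y))
print(f"SVM: {svm_time:.3f}, FastCAV: {fastcav_time:.3f}, "
      f"speed-up: {svm_time / fastcav_time:.1f}x")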

Visual Benchmarking

Given the example in the tutorial tutorials/TCAV_Image.ipynb, we get the following results for experimental_set_rand and experimental_set_zig_dot with FastCAV and SVM:

[Image: results for experimental_set_rand with FastCAV]
[Image: results for experimental_set_rand with SVM]
[Image: results for experimental_set_zig_dot with FastCAV]
[Image: results for experimental_set_zig_dot with SVM]

Mathematical Assumptions:

Under the following assumptions, FastCAV is equivalent to an SVM (a sketch of the argument follows the list):

  • Gaussian Distribution: The activation vectors for both the random samples and the concept samples are assumed to follow independent multivariate Gaussian distributions.
  • Equal Mixture: The set of concept examples and the set of random examples are of equal size ($|D_c| = |D_r|$), resulting in a uniform mixture of the two Gaussian distributions.
  • Isotropic Covariance: The within-class covariance matrices are assumed to be isotropic, meaning they are proportional to the unit matrix. This is a critical assumption that makes the FastCAV solution equivalent to the solution of a Fisher discriminant analysis.
  • High-Dimensionality: The method is analyzed in the context of high-dimensional activation spaces, where the number of dimensions $d$ is significantly larger than the number of samples $n$ ($d \gg n$). In such spaces, the set of support vectors used by an SVM is likely to contain most of the training samples, making the SVM solution approximate the Fisher discriminant solution, and by extension, the FastCAV solution.
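
Sketch of the argument (our paraphrase of the reasoning above, not the full proof from the paper): with an equal mixture, the global mean is $\mu = \frac{1}{2}(\mu_c + \mu_r)$, so the FastCAV direction is $\mu_c - \mu = \frac{1}{2}(\mu_c - \mu_r)$. The Fisher discriminant direction is $w \propto \Sigma_W^{-1}(\mu_c - \mu_r)$, which under isotropic within-class covariance $\Sigma_W = \sigma^2 I$ reduces to $w \propto \mu_c - \mu_r$, i.e. the FastCAV direction up to scale. For $d \gg n$, nearly all training samples act as support vectors, so the SVM solution approaches the Fisher solution and hence the FastCAV solution.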

Proposed Captum API Design:

The integration of FastCAV into Captum is designed to be a high-performance drop-in replacement for the default DefaultClassifier, which utilizes an SVM. It is exposed to the user through the FastCAVClassifier class, which leverages a FastCAVLinearModel internally.
E.g., for the tutorial tutorials/TCAV_Image.ipynb:

fast_clf = classifier.FastCAVClassifier()
mytcav = TCAV(model=model,
              layers=layers,
              classifier=fast_clf,
              layer_attr_method=LayerIntegratedGradients(
                  model, None, multiply_by_inputs=False))

FastCAV

This is a low-level utility class, similar in interface to scikit-learn classifiers. It contains the core logic for computing the CAV. It is not a torch.nn.Module and is not intended for direct use within most Captum workflows, but provides the fundamental algorithm.
Constructor:

FastCAV(**kwargs)

Argument Descriptions:

  • kwargs - The constructor currently accepts but ignores any keyword arguments to maintain a consistent interface with other classifiers.

Methods:

  • fit(x: Tensor, y: Tensor): Takes a tensor x of activations and a tensor y of labels and computes coef_ and intercept_.
      • x: A 2D tensor of shape (n_samples, n_features) containing the input data (e.g., model activations).
      • y: A 1D tensor of shape (n_samples,) containing binary labels (0 or 1).
  • predict(x: Tensor) -> Tensor: Predicts class labels for new data points based on the fitted hyperplane.
  • classes() -> Tensor: Returns the unique class labels the model was fitted on.
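
A short illustrative usage of the low-level class on synthetic activations (the shapes and labels below are made up for the example and assume the fit method from the pseudocode above):

import torch

acts = torch.randn(64, 512)               # (n_samples, n_features) activations
labels = (torch.arange(64) >= 32).long()  # 0 = random/control, 1 = concept

cav = FastCAV()
cav.fit(acts, labels)                     # computes coef_ (the CAV) and intercept_
preds = cav.predict(acts)                 # labels predicted from the fitted hyperplane
print(cav.classes(), cav.coef_.shape)     # expected: tensor([0, 1]) and torch.Size([1, 512])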

fastcav_train_linear_model

This function serves as the bridge between the FastCAV logic and Captum's LinearModel interface. It is designed to be used as the train_fn for a LinearModel. It orchestrates the process of fitting a FastCAV instance and then using its learned parameters to configure a LinearModel object (a rough sketch follows the argument descriptions below).
Signature:

fastcav_train_linear_model(model: LinearModel, dataloader: DataLoader, construct_kwargs: Dict[str, Any], norm_input: bool = False, **fit_kwargs: Any) -> Dict[str, float]

Argument Descriptions:

  • model: The LinearModel instance to be configured.
  • dataloader: A DataLoader providing the training data (activations and labels). Right now this follows sklearn_train_linear_model and iterates through the entire dataloader to collect all data into memory.
  • construct_kwargs: Keyword arguments passed to the FastCAV constructor. Since FastCAV ignores its constructor kwargs, this argument exists only to maintain a consistent interface.
  • norm_input: A boolean indicating whether to normalize the input data before fitting.
  • fit_kwargs: Additional keyword arguments for the fit method (currently unused by FastCAV).
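
A rough sketch of the bridging function, modeled on Captum's existing sklearn_train_linear_model; the exact _construct_model_params keywords and the batch layout of the dataloader are assumptions, not the final implementation:

from typing import Any, Dict

import torch
from torch.utils.data import DataLoader
from captum._utils.models.linear_model import LinearModel

def fastcav_train_linear_model(
    model: LinearModel,
    dataloader: DataLoader,
    construct_kwargs: Dict[str, Any],
    norm_input: bool = False,
    **fit_kwargs: Any,
) -> Dict[str, float]:
    # Collect all activations and labels into memory, as sklearn_train_linear_model does.
    xs, ys = zip(*((batch[0], batch[1]) for batch in dataloader))
    x, y = torch.cat(xs), torch.cat(ys)

    if norm_input:
        x = (x - x.mean(dim=0)) / (x.std(dim=0) + 1e-8)

    # Fit the core FastCAV estimator, then copy its parameters into the LinearModel.
    fastcav = FastCAV(**construct_kwargs)
    fastcav.fit(x, y)  # fit_kwargs are currently unused by FastCAV
    model._construct_model_params(
        norm_type=None,
        weight_values=fastcav.coef_,
        bias_value=fastcav.intercept_.flatten(),
        classes=fastcav.classes_,
    )
    return {}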

FastCAVLinearModel

This class acts as a bridge, wrapping the FastCAV logic into the LinearModel interface that is standard within Captum's concept utilities. This allows it to be used by higher-level abstractions.
Constructor:

FastCAVLinearModel(**kwargs)

  • The constructor simply calls the parent LinearModel constructor, passing fastcav_train_linear_model as the train_fn. Any kwargs are stored and passed to the training function during the fit call.
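
A minimal sketch of the wrapper, assuming LinearModel (from captum._utils.models.linear_model) takes the training function as its train_fn argument, as it does for the existing scikit-learn wrappers:

from captum._utils.models.linear_model import LinearModel

class FastCAVLinearModel(LinearModel):
    def __init__(self, **kwargs) -> None:
        # Delegate training to the FastCAV bridge function described above.
        super().__init__(train_fn=fastcav_train_linear_model, **kwargs)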

FastCAVClassifier

This is the main user-facing class. It is a drop-in replacement for DefaultClassifier for users of high-level APIs like TCAV, or for future concept-based explanation methods such as ACE. By simply switching the classifier, users can leverage the performance benefits of FastCAV without changing the rest of their workflow. It inherits from DefaultClassifier and uses a FastCAVLinearModel as its internal engine (a minimal sketch follows the constructor below).
Constructor:

FastCAVClassifier()
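
One possible shape of the class; that DefaultClassifier keeps its internal linear model in self.lm is an assumption based on the current Captum implementation:

from captum.concept._utils.classifier import DefaultClassifier

class FastCAVClassifier(DefaultClassifier):
    def __init__(self) -> None:
        super().__init__()
        # Replace the default SGD/SVM-based linear model with the FastCAV-backed one.
        self.lm = FastCAVLinearModel()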
