From aa52b2622086e66e53fe975620290518dfec1952 Mon Sep 17 00:00:00 2001
From: Fanyu Meng
Date: Mon, 21 Jul 2025 13:34:14 -0700
Subject: [PATCH] Allow more variable input to LLMAttributionResult (#1627)

Summary:
Pull Request resolved: https://github.com/pytorch/captum/pull/1627

1. Allow `LLMAttributionResult` to be initialized with generic array data
   (lists, np.ndarray) and perform sanity checks on their shapes;
2. During visualization, the text tokens are now `repr`'d to make sure that
   non-word characters (e.g. newline) are visualized correctly.

Reviewed By: craymichael

Differential Revision: D78197863
---
 captum/attr/_core/llm_attr.py | 117 +++++++++++++++++++++++++++++-----
 1 file changed, 102 insertions(+), 15 deletions(-)

diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py
index 90540b1bd..0e6ffd1dc 100644
--- a/captum/attr/_core/llm_attr.py
+++ b/captum/attr/_core/llm_attr.py
@@ -5,13 +5,14 @@
 from abc import ABC
 from copy import copy
 from dataclasses import dataclass
-from textwrap import shorten
+from textwrap import dedent, shorten
 from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union
 
 import matplotlib.colors as mcolors
 import matplotlib.pyplot as plt
 import numpy as np
+import numpy.typing as npt
 import torch
 
 from captum._utils.typing import TokenizerLike
@@ -51,11 +52,92 @@ class LLMAttributionResult:
     It also provides utilities to help present and plot the result in different forms.
     """
 
-    seq_attr: Tensor
-    token_attr: Optional[Tensor]
     input_tokens: List[str]
     output_tokens: List[str]
-    output_probs: Optional[Tensor] = None
+    # pyre-ignore[13]: initialized via a property setter
+    _seq_attr: Tensor
+    _token_attr: Optional[Tensor] = None
+    _output_probs: Optional[Tensor] = None
+
+    def __init__(
+        self,
+        *,
+        input_tokens: List[str],
+        output_tokens: List[str],
+        seq_attr: npt.ArrayLike,
+        token_attr: Optional[npt.ArrayLike] = None,
+        output_probs: Optional[npt.ArrayLike] = None,
+    ) -> None:
+        self.input_tokens = input_tokens
+        self.output_tokens = output_tokens
+        self.seq_attr = seq_attr
+        self.token_attr = token_attr
+        self.output_probs = output_probs
+
+    @property
+    def seq_attr(self) -> Tensor:
+        return self._seq_attr
+
+    @seq_attr.setter
+    def seq_attr(self, seq_attr: npt.ArrayLike) -> None:
+        if isinstance(seq_attr, Tensor):
+            self._seq_attr = seq_attr
+        else:
+            self._seq_attr = torch.tensor(seq_attr)
+        # IDEA: in the future we might want to support higher dim seq_attr
+        # (e.g. attention w.r.t. multiple layers, gradients w.r.t.
+        # different classes)
+        assert len(self._seq_attr.shape) == 1, "seq_attr must be a 1D tensor"
+        assert (
+            len(self.input_tokens) == self._seq_attr.shape[0]
+        ), "seq_attr and input_tokens must have the same length"
+
+    @property
+    def token_attr(self) -> Optional[Tensor]:
+        return self._token_attr
+
+    @token_attr.setter
+    def token_attr(self, token_attr: Optional[npt.ArrayLike]) -> None:
+        if token_attr is None:
+            self._token_attr = None
+        elif isinstance(token_attr, Tensor):
+            self._token_attr = token_attr
+        else:
+            self._token_attr = torch.tensor(token_attr)
+
+        if self._token_attr is not None:
+            # IDEA: in the future we might want to support higher dim token_attr
+            assert len(self._token_attr.shape) == 2, "token_attr must be a 2D tensor"
+            assert self._token_attr.shape == (
+                len(self.output_tokens),
+                len(self.input_tokens),
+            ), dedent(
+                f"""\
+                Expect token_attr to have shape
+                {len(self.output_tokens), len(self.input_tokens)},
+                got {self._token_attr.shape}
+                """
+            )
+
+    @property
+    def output_probs(self) -> Optional[Tensor]:
+        return self._output_probs
+
+    @output_probs.setter
+    def output_probs(self, output_probs: Optional[npt.ArrayLike]) -> None:
+        if output_probs is None:
+            self._output_probs = None
+        elif isinstance(output_probs, Tensor):
+            self._output_probs = output_probs
+        else:
+            self._output_probs = torch.tensor(output_probs)
+
+        if self._output_probs is not None:
+            assert (
+                len(self._output_probs.shape) == 1
+            ), "output_probs must be a 1D tensor"
+            assert (
+                len(self.output_tokens) == self._output_probs.shape[0]
+            ), "output_probs and output_tokens must have the same length"
 
     @property
     def seq_attr_dict(self) -> Dict[str, float]:
@@ -125,10 +207,14 @@ def plot_token_attr(
 
         # Show all ticks and label them with the respective list entries.
         shortened_tokens = [
-            shorten(t, width=50, placeholder="...") for t in self.input_tokens
+            shorten(repr(t)[1:-1], width=50, placeholder="...")
+            for t in self.input_tokens
         ]
         ax.set_xticks(np.arange(data.shape[1]), labels=shortened_tokens)
-        ax.set_yticks(np.arange(data.shape[0]), labels=self.output_tokens)
+        ax.set_yticks(
+            np.arange(data.shape[0]),
+            labels=[repr(token)[1:-1] for token in self.output_tokens],
+        )
 
         # Let the horizontal axes labeling appear on top.
         ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
@@ -176,7 +262,8 @@ def plot_seq_attr(
         fig.set_size_inches(max(data.shape[0] / 2, 6.4), max(data.shape[0] / 4, 4.8))
 
         shortened_tokens = [
-            shorten(t, width=50, placeholder="...") for t in self.input_tokens
+            shorten(repr(t)[1:-1], width=50, placeholder="...")
+            for t in self.input_tokens
         ]
         ax.set_xticks(range(data.shape[0]), labels=shortened_tokens)
 
@@ -690,12 +777,12 @@ def attribute(
             attr = inp.format_attr(attr)
 
         return LLMAttributionResult(
-            attr[0],
-            (
+            seq_attr=attr[0],
+            token_attr=(
                 attr[1:] if self.include_per_token_attr else None
             ),  # shape(n_output_token, n_input_features)
-            inp.values,
-            _convert_ids_to_pretty_tokens(target_tokens, self.tokenizer),
+            input_tokens=inp.values,
+            output_tokens=_convert_ids_to_pretty_tokens(target_tokens, self.tokenizer),
         )
 
     def attribute_future(self) -> Callable[[], LLMAttributionResult]:
@@ -830,10 +917,10 @@ def attribute(
         seq_attr = attr.sum(0)
 
         return LLMAttributionResult(
-            seq_attr,
-            attr,  # shape(n_output_token, n_input_features)
-            inp.values,
-            _convert_ids_to_pretty_tokens(target_tokens, self.tokenizer),
+            seq_attr=seq_attr,
+            token_attr=attr,  # shape(n_output_token, n_input_features)
+            input_tokens=inp.values,
+            output_tokens=_convert_ids_to_pretty_tokens(target_tokens, self.tokenizer),
         )
 
     def attribute_future(self) -> Callable[[], LLMAttributionResult]:
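
Usage sketch (editor's note, not part of the patch): with this change applied,
the attribution fields accept plain lists or numpy arrays and are validated
against the token lists at assignment time. The token strings and values below
are made up for illustration; the import path is the module touched by the diff.

    import numpy as np
    from captum.attr._core.llm_attr import LLMAttributionResult

    # Two output tokens, four input tokens; arguments are keyword-only.
    result = LLMAttributionResult(
        input_tokens=["Hello", ",", "\n", "world"],
        output_tokens=["Hi", "!"],
        seq_attr=[0.1, 0.05, 0.0, 0.6],  # plain list -> 1D tensor, len == n_input
        token_attr=np.zeros((2, 4)),     # ndarray -> 2D tensor, (n_output, n_input)
        output_probs=[0.9, 0.8],         # one probability per output token
    )

    # The property setters re-run the shape checks on every assignment:
    try:
        result.token_attr = np.zeros((4, 2))  # transposed; expected (2, 4)
    except AssertionError as err:
        print(err)  # roughly: Expect token_attr to have shape (2, 4), got ...

    # In plot_token_attr/plot_seq_attr the tick labels are repr()'d with the
    # surrounding quotes stripped, so the "\n" input token renders as the two
    # characters "\n" instead of producing an actual line break in the label.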