Allow more variable input to LLMAttributionResult #1627

Status: Closed (wanted to merge 1 commit)
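This change lets LLMAttributionResult accept any npt.ArrayLike (plain lists, NumPy arrays, or tensors) for seq_attr, token_attr, and output_probs, makes the constructor keyword-only, and moves shape validation into property setters so malformed inputs fail fast. A minimal usage sketch against the patched class; the tokens and values below are hypothetical, and the import path assumes LLMAttributionResult is re-exported from captum.attr:

import numpy as np
from captum.attr import LLMAttributionResult

# Hypothetical tokens and attribution scores, for illustration only.
result = LLMAttributionResult(
    input_tokens=["The", "capital", "of", "France"],
    output_tokens=["Paris"],
    seq_attr=np.array([0.10, 0.25, 0.05, 0.90]),  # NumPy array now accepted
    token_attr=[[0.10, 0.25, 0.05, 0.90]],  # plain nested lists work too
)

# The setters convert array-likes to torch.Tensor and validate shapes.
assert result.seq_attr.shape == (4,)  # (n_input_tokens,)
assert result.token_attr.shape == (1, 4)  # (n_output_tokens, n_input_tokens)

Because the constructor is now keyword-only, positional calls such as LLMAttributionResult(attr[0], ...) no longer work; the two attribute() call sites in the diff below are updated accordingly.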
117 changes: 102 additions & 15 deletions captum/attr/_core/llm_attr.py
@@ -5,13 +5,14 @@
 from abc import ABC
 from copy import copy
 from dataclasses import dataclass
-from textwrap import shorten
+from textwrap import dedent, shorten
 from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union

 import matplotlib.colors as mcolors

 import matplotlib.pyplot as plt
 import numpy as np
+import numpy.typing as npt

 import torch
 from captum._utils.typing import TokenizerLike
@@ -51,11 +52,92 @@ class LLMAttributionResult:
     It also provides utilities to help present and plot the result in different forms.
     """

-    seq_attr: Tensor
-    token_attr: Optional[Tensor]
-    input_tokens: List[str]
-    output_tokens: List[str]
-    output_probs: Optional[Tensor] = None
+    # pyre-ignore[13]: initialized via a property setter
+    _seq_attr: Tensor
+    _token_attr: Optional[Tensor] = None
+    _output_probs: Optional[Tensor] = None
+
+    def __init__(
+        self,
+        *,
+        input_tokens: List[str],
+        output_tokens: List[str],
+        seq_attr: npt.ArrayLike,
+        token_attr: Optional[npt.ArrayLike] = None,
+        output_probs: Optional[npt.ArrayLike] = None,
+    ) -> None:
+        self.input_tokens = input_tokens
+        self.output_tokens = output_tokens
+        self.seq_attr = seq_attr
+        self.token_attr = token_attr
+        self.output_probs = output_probs
+
+    @property
+    def seq_attr(self) -> Tensor:
+        return self._seq_attr
+
+    @seq_attr.setter
+    def seq_attr(self, seq_attr: npt.ArrayLike) -> None:
+        if isinstance(seq_attr, Tensor):
+            self._seq_attr = seq_attr
+        else:
+            self._seq_attr = torch.tensor(seq_attr)
+        # IDEA: in the future we might want to support higher dim seq_attr
+        # (e.g. attention w.r.t. multiple layers, gradients w.r.t. different classes)
+        assert len(self._seq_attr.shape) == 1, "seq_attr must be a 1D tensor"
+        assert (
+            len(self.input_tokens) == self._seq_attr.shape[0]
+        ), "seq_attr and input_tokens must have the same length"
+
+    @property
+    def token_attr(self) -> Optional[Tensor]:
+        return self._token_attr
+
+    @token_attr.setter
+    def token_attr(self, token_attr: Optional[npt.ArrayLike]) -> None:
+        if token_attr is None:
+            self._token_attr = None
+        elif isinstance(token_attr, Tensor):
+            self._token_attr = token_attr
+        else:
+            self._token_attr = torch.tensor(token_attr)
+
+        if self._token_attr is not None:
+            # IDEA: in the future we might want to support higher dim token_attr
+            assert len(self._token_attr.shape) == 2, "token_attr must be a 2D tensor"
+            assert self._token_attr.shape == (
+                len(self.output_tokens),
+                len(self.input_tokens),
+            ), dedent(
+                f"""\
+                Expect token_attr to have shape
+                {len(self.output_tokens), len(self.input_tokens)},
+                got {self._token_attr.shape}
+                """
+            )
+
+    @property
+    def output_probs(self) -> Optional[Tensor]:
+        return self._output_probs
+
+    @output_probs.setter
+    def output_probs(self, output_probs: Optional[npt.ArrayLike]) -> None:
+        if output_probs is None:
+            self._output_probs = None
+        elif isinstance(output_probs, Tensor):
+            self._output_probs = output_probs
+        else:
+            self._output_probs = torch.tensor(output_probs)
+
+        if self._output_probs is not None:
+            assert (
+                len(self._output_probs.shape) == 1
+            ), "output_probs must be a 1D tensor"
+            assert (
+                len(self.output_tokens) == self._output_probs.shape[0]
+            ), "output_probs and output_tokens must have the same length"

     @property
     def seq_attr_dict(self) -> Dict[str, float]:
Expand Down Expand Up @@ -125,10 +207,14 @@ def plot_token_attr(

# Show all ticks and label them with the respective list entries.
shortened_tokens = [
shorten(t, width=50, placeholder="...") for t in self.input_tokens
shorten(repr(t)[1:-1], width=50, placeholder="...")
for t in self.input_tokens
]
ax.set_xticks(np.arange(data.shape[1]), labels=shortened_tokens)
ax.set_yticks(np.arange(data.shape[0]), labels=self.output_tokens)
ax.set_yticks(
np.arange(data.shape[0]),
labels=[repr(token)[1:-1] for token in self.output_tokens],
)

# Let the horizontal axes labeling appear on top.
ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
@@ -176,7 +262,8 @@ def plot_seq_attr(
         fig.set_size_inches(max(data.shape[0] / 2, 6.4), max(data.shape[0] / 4, 4.8))

         shortened_tokens = [
-            shorten(t, width=50, placeholder="...") for t in self.input_tokens
+            shorten(repr(t)[1:-1], width=50, placeholder="...")
+            for t in self.input_tokens
         ]
         ax.set_xticks(range(data.shape[0]), labels=shortened_tokens)
Expand Down Expand Up @@ -690,12 +777,12 @@ def attribute(
attr = inp.format_attr(attr)

return LLMAttributionResult(
attr[0],
(
seq_attr=attr[0],
token_attr=(
attr[1:] if self.include_per_token_attr else None
), # shape(n_output_token, n_input_features)
inp.values,
_convert_ids_to_pretty_tokens(target_tokens, self.tokenizer),
input_tokens=inp.values,
output_tokens=_convert_ids_to_pretty_tokens(target_tokens, self.tokenizer),
)

def attribute_future(self) -> Callable[[], LLMAttributionResult]:
@@ -830,10 +917,10 @@ def attribute(
         seq_attr = attr.sum(0)

         return LLMAttributionResult(
-            seq_attr,
-            attr,  # shape(n_output_token, n_input_features)
-            inp.values,
-            _convert_ids_to_pretty_tokens(target_tokens, self.tokenizer),
+            seq_attr=seq_attr,
+            token_attr=attr,  # shape(n_output_token, n_input_features)
+            input_tokens=inp.values,
+            output_tokens=_convert_ids_to_pretty_tokens(target_tokens, self.tokenizer),
         )

     def attribute_future(self) -> Callable[[], LLMAttributionResult]:
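As the setters above show, invalid shapes now raise at construction time rather than surfacing later in plotting code. A quick sketch of the failure mode, with hypothetical values (import path again assumes the captum.attr re-export):

import numpy as np
from captum.attr import LLMAttributionResult

try:
    LLMAttributionResult(
        input_tokens=["a", "b", "c"],
        output_tokens=["x"],
        seq_attr=np.zeros(2),  # mismatch: 3 input tokens, 2 attribution values
    )
except AssertionError as err:
    print(err)  # -> "seq_attr and input_tokens must have the same length"

Separately, plot_token_attr and plot_seq_attr now label ticks with repr(token)[1:-1], so tokens containing control characters such as newlines render as escaped text ("\n") instead of breaking the axis labels.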